1 // 2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "armnnOnnxParser/IOnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 #include "OnnxParserTestUtils.hpp" 9 10 TEST_SUITE("OnnxParser_Gemm") 11 { 12 13 struct GemmFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 14 { GemmFixtureGemmFixture15 GemmFixture(const std::string& alpha, 16 const std::string& beta, 17 const std::string& transA, 18 const std::string& transB, 19 const std::vector<int>& inputAShape, 20 const std::vector<int>& inputBShape, 21 const std::vector<int>& inputCShape, 22 const std::vector<int>& outputShape) 23 { 24 m_Prototext = R"( 25 ir_version: 8 26 producer_name: "onnx-example" 27 graph { 28 node { 29 input: "A" 30 input: "B" 31 input: "C" 32 output: "Output" 33 op_type: "Gemm" 34 attribute { 35 name: "alpha" 36 f: )" + alpha + R"( 37 type: FLOAT 38 } 39 attribute { 40 name: "beta" 41 f: )" + beta + R"( 42 type: FLOAT 43 } 44 attribute { 45 name: "transA" 46 i: )" + transA + R"( 47 type: INT 48 } 49 attribute { 50 name: "transB" 51 i: )" + transB + R"( 52 type: INT 53 } 54 } 55 name: "gem-model" 56 input { 57 name: "A" 58 type { 59 tensor_type { 60 elem_type: 1 61 shape { 62 )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"( 63 } 64 } 65 } 66 } 67 input { 68 name: "B" 69 type { 70 tensor_type { 71 elem_type: 1 72 shape { 73 )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"( 74 } 75 } 76 } 77 } 78 input { 79 name: "C" 80 type { 81 tensor_type { 82 elem_type: 1 83 shape { 84 )" + armnnUtils::ConstructTensorShapeString(inputCShape) + R"( 85 } 86 } 87 } 88 } 89 output { 90 name: "Output" 91 type { 92 tensor_type { 93 elem_type: 1 94 shape { 95 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( 96 } 97 } 98 } 99 } 100 })"; 101 } 102 }; 103 104 struct GemmAllAttributesFixture : GemmFixture 105 { GemmAllAttributesFixtureGemmAllAttributesFixture106 GemmAllAttributesFixture() : GemmFixture("0.25", "0.35", "1", "1", { 4, 3 }, { 5, 4 }, { 5 }, { 3, 5 }) 107 { 108 Setup(); 109 } 110 }; 111 112 struct GemmSimpleFixture : GemmFixture 113 { GemmSimpleFixtureGemmSimpleFixture114 GemmSimpleFixture() : GemmFixture("1", "1", "0", "0", { 3, 4 }, { 4, 5 }, { 5 }, { 3, 5 }) 115 { 116 Setup(); 117 } 118 }; 119 120 struct GemmTransAFixture : GemmFixture 121 { GemmTransAFixtureGemmTransAFixture122 GemmTransAFixture() : GemmFixture("1", "1", "1", "0", { 4, 3 }, { 4, 5 }, { 5 }, { 3, 5 }) 123 { 124 Setup(); 125 } 126 }; 127 128 struct GemmTransBFixture : GemmFixture 129 { GemmTransBFixtureGemmTransBFixture130 GemmTransBFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 5 }, { 3, 5 }) 131 { 132 Setup(); 133 } 134 }; 135 136 struct GemmParseExceptionFixture : GemmFixture 137 { GemmParseExceptionFixtureGemmParseExceptionFixture138 GemmParseExceptionFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }, { 3, 5 }) {} 139 }; 140 141 TEST_CASE_FIXTURE(GemmAllAttributesFixture, "GemmTest") 142 { 143 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 144 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, 145 {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 146 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 147 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 148 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}, 149 {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}}, 150 {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f, 151 12.535f, 38.57f, 64.605f, 90.64f, 116.675f, 152 10.035f, 32.07f, 54.105f, 76.14f, 98.175f }}}); 153 } 154 155 TEST_CASE_FIXTURE(GemmSimpleFixture, "GemmSimpleTest") 156 { 157 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 158 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, 159 {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 160 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 161 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 162 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}, 163 {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}}, 164 {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f, 165 196.1f, 222.2f, 248.3f, 274.4f, 300.5f, 166 60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}}); 167 } 168 169 TEST_CASE_FIXTURE(GemmTransAFixture, "GemmTransposeATest") 170 { 171 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 172 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, 173 {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 174 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 175 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 176 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}, 177 {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}}, 178 {{"Output", { 180.1f, 210.2f, 240.3f, 270.4f, 300.5f, 179 146.1f, 172.2f, 198.3f, 224.4f, 250.5f, 180 112.1f, 134.2f, 156.3f, 178.4f, 200.5f }}}); 181 } 182 183 TEST_CASE_FIXTURE(GemmTransBFixture, "GemmTransposeBTest") 184 { 185 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 186 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, 187 {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 188 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 189 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 190 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}, 191 {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}}, 192 {{"Output", { 100.1f, 268.2f, 436.3f, 604.4f, 772.5f, 193 60.1f, 164.2f, 268.3f, 372.4f, 476.5f, 194 20.1f, 60.2f, 100.3f, 140.4f, 180.5f }}}); 195 } 196 197 TEST_CASE_FIXTURE(GemmParseExceptionFixture, "GemmParseExceptionTest") 198 { 199 // ParseException because Input C is non-constant and has 2 dimension (should be 1 dimension) 200 CHECK_THROWS_AS(Setup(), armnn::ParseException); 201 } 202 203 struct GemmConstantFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 204 { GemmConstantFixtureGemmConstantFixture205 GemmConstantFixture() 206 { 207 m_Prototext = R"( 208 ir_version: 8 209 producer_name: "onnx-example" 210 graph { 211 node { 212 input: "A" 213 input: "B" 214 input: "C" 215 output: "Output" 216 op_type: "Gemm" 217 attribute { 218 name: "alpha" 219 f: 0.25 220 type: FLOAT 221 } 222 attribute { 223 name: "beta" 224 f: 0.35 225 type: FLOAT 226 } 227 attribute { 228 name: "transA" 229 i: 1 230 type: INT 231 } 232 attribute { 233 name: "transB" 234 i: 1 235 type: INT 236 } 237 } 238 name: "gem-model" 239 initializer { 240 dims: 5 241 dims: 4 242 data_type: 1 243 float_data: 1.0 244 float_data: 2.0 245 float_data: 3.0 246 float_data: 4.0 247 float_data: 5.0 248 float_data: 6.0 249 float_data: 7.0 250 float_data: 8.0 251 float_data: 9.0 252 float_data: 10.0 253 float_data: 11.0 254 float_data: 12.0 255 float_data: 13.0 256 float_data: 14.0 257 float_data: 15.0 258 float_data: 16.0 259 float_data: 17.0 260 float_data: 18.0 261 float_data: 19.0 262 float_data: 20.0 263 name: "B" 264 } 265 initializer { 266 dims: 1 267 dims: 5 268 data_type: 1 269 float_data: 0.1 270 float_data: 0.2 271 float_data: 0.3 272 float_data: 0.4 273 float_data: 0.5 274 name: "C" 275 } 276 input { 277 name: "A" 278 type { 279 tensor_type { 280 elem_type: 1 281 shape { 282 dim { 283 dim_value: 4 284 } 285 dim { 286 dim_value: 3 287 } 288 } 289 } 290 } 291 } 292 output { 293 name: "Output" 294 type { 295 tensor_type { 296 elem_type: 1 297 shape { 298 dim { 299 dim_value: 3 300 } 301 dim { 302 dim_value: 5 303 } 304 } 305 } 306 } 307 } 308 })"; 309 Setup(); 310 } 311 }; 312 313 TEST_CASE_FIXTURE(GemmConstantFixture, "GemmConstantTest") 314 { 315 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 316 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}}, 317 {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f, 318 12.535f, 38.57f, 64.605f, 90.64f, 116.675f, 319 10.035f, 32.07f, 54.105f, 76.14f, 98.175f }}}); 320 } 321 322 struct GemmConstantSimpleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 323 { GemmConstantSimpleFixtureGemmConstantSimpleFixture324 GemmConstantSimpleFixture() 325 { 326 m_Prototext = R"( 327 ir_version: 8 328 producer_name: "onnx-example" 329 graph { 330 node { 331 input: "A" 332 input: "B" 333 input: "C" 334 output: "Output" 335 op_type: "Gemm" 336 attribute { 337 name: "alpha" 338 f: 1 339 type: FLOAT 340 } 341 attribute { 342 name: "beta" 343 f: 1 344 type: FLOAT 345 } 346 attribute { 347 name: "transA" 348 i: 0 349 type: INT 350 } 351 attribute { 352 name: "transB" 353 i: 0 354 type: INT 355 } 356 } 357 name: "gem-model" 358 initializer { 359 dims: 4 360 dims: 5 361 data_type: 1 362 float_data: 1.0 363 float_data: 2.0 364 float_data: 3.0 365 float_data: 4.0 366 float_data: 5.0 367 float_data: 6.0 368 float_data: 7.0 369 float_data: 8.0 370 float_data: 9.0 371 float_data: 10.0 372 float_data: 11.0 373 float_data: 12.0 374 float_data: 13.0 375 float_data: 14.0 376 float_data: 15.0 377 float_data: 16.0 378 float_data: 17.0 379 float_data: 18.0 380 float_data: 19.0 381 float_data: 20.0 382 name: "B" 383 } 384 initializer { 385 dims: 1 386 dims: 5 387 data_type: 1 388 float_data: 0.1 389 float_data: 0.2 390 float_data: 0.3 391 float_data: 0.4 392 float_data: 0.5 393 name: "C" 394 } 395 input { 396 name: "A" 397 type { 398 tensor_type { 399 elem_type: 1 400 shape { 401 dim { 402 dim_value: 3 403 } 404 dim { 405 dim_value: 4 406 } 407 } 408 } 409 } 410 } 411 output { 412 name: "Output" 413 type { 414 tensor_type { 415 elem_type: 1 416 shape { 417 dim { 418 dim_value: 3 419 } 420 dim { 421 dim_value: 5 422 } 423 } 424 } 425 } 426 } 427 })"; 428 Setup(); 429 } 430 }; 431 432 TEST_CASE_FIXTURE(GemmConstantSimpleFixture, "GemmConstantSimpleTest") 433 { 434 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 435 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}}, 436 {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f, 437 196.1f, 222.2f, 248.3f, 274.4f, 300.5f, 438 60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}}); 439 } 440 441 struct GemmABFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 442 { GemmABFixtureGemmABFixture443 GemmABFixture(const std::string& alpha, 444 const std::string& beta, 445 const std::string& transA, 446 const std::string& transB, 447 const std::vector<int>& inputAShape, 448 const std::vector<int>& inputBShape, 449 const std::vector<int>& outputShape) 450 { 451 m_Prototext = R"( 452 ir_version: 8 453 producer_name: "onnx-example" 454 graph { 455 node { 456 input: "A" 457 input: "B" 458 output: "Output" 459 op_type: "Gemm" 460 attribute { 461 name: "alpha" 462 f: )" + alpha + R"( 463 type: FLOAT 464 } 465 attribute { 466 name: "beta" 467 f: )" + beta + R"( 468 type: FLOAT 469 } 470 attribute { 471 name: "transA" 472 i: )" + transA + R"( 473 type: INT 474 } 475 attribute { 476 name: "transB" 477 i: )" + transB + R"( 478 type: INT 479 } 480 } 481 name: "gem-model" 482 input { 483 name: "A" 484 type { 485 tensor_type { 486 elem_type: 1 487 shape { 488 )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"( 489 } 490 } 491 } 492 } 493 input { 494 name: "B" 495 type { 496 tensor_type { 497 elem_type: 1 498 shape { 499 )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"( 500 } 501 } 502 } 503 } 504 output { 505 name: "Output" 506 type { 507 tensor_type { 508 elem_type: 1 509 shape { 510 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( 511 } 512 } 513 } 514 } 515 })"; 516 Setup(); 517 } 518 }; 519 520 struct GemmAlphaTransAFixture : GemmABFixture 521 { GemmAlphaTransAFixtureGemmAlphaTransAFixture522 GemmAlphaTransAFixture() : GemmABFixture("0.25", "0.35", "1", "0", { 4, 3 }, { 4, 5 }, { 3, 5 }) {} 523 }; 524 525 struct GemmAlphaTransBFixture : GemmABFixture 526 { GemmAlphaTransBFixtureGemmAlphaTransBFixture527 GemmAlphaTransBFixture() : GemmABFixture("0.25", "0.35", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }) {} 528 }; 529 530 TEST_CASE_FIXTURE(GemmAlphaTransAFixture, "GemmAlphaTransATest") 531 { 532 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 533 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, 534 {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 535 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 536 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 537 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}}, 538 {{"Output", { 45.0f, 52.5f, 60.0f, 67.5f, 75.0f, 539 36.5f, 43.0f, 49.5f, 56.0f, 62.5f, 540 28.0f, 33.5f, 39.0f, 44.5f, 50.0f }}}); 541 } 542 543 TEST_CASE_FIXTURE(GemmAlphaTransBFixture, "GemmAlphaTransBTest") 544 { 545 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 546 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, 547 {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 548 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 549 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 550 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}}, 551 {{"Output", { 25.0f, 67.0f, 109.0f, 151.0f, 193.0f, 552 15.0f, 41.0f, 67.0f, 93.0f, 119.0f, 553 5.0f, 15.0f, 25.0f, 35.0f, 45.0f }}}); 554 } 555 556 } 557