1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "armnnOnnxParser/IOnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 9 TEST_SUITE("OnnxParser_Conv2D") 10 { 11 struct SimpleConv2DFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 12 { SimpleConv2DFixtureSimpleConv2DFixture13 SimpleConv2DFixture() 14 { 15 m_Prototext = R"( 16 ir_version: 3 17 producer_name: "CNTK" 18 producer_version: "2.5.1" 19 domain: "ai.cntk" 20 model_version: 1 21 graph { 22 name: "CNTKGraph" 23 input { 24 name: "Input" 25 type { 26 tensor_type { 27 elem_type: 1 28 shape { 29 dim { 30 dim_value: 1 31 } 32 dim { 33 dim_value: 1 34 } 35 dim { 36 dim_value: 3 37 } 38 dim { 39 dim_value: 3 40 } 41 } 42 } 43 } 44 } 45 input { 46 name: "Weight" 47 type { 48 tensor_type { 49 elem_type: 1 50 shape { 51 dim { 52 dim_value: 1 53 } 54 dim { 55 dim_value: 1 56 } 57 dim { 58 dim_value: 3 59 } 60 dim { 61 dim_value: 3 62 } 63 } 64 } 65 } 66 } 67 initializer { 68 dims: 1 69 dims: 1 70 dims: 3 71 dims: 3 72 data_type: 1 73 float_data: 2 74 float_data: 1 75 float_data: 0 76 float_data: 6 77 float_data: 2 78 float_data: 1 79 float_data: 4 80 float_data: 1 81 float_data: 2 82 name: "Weight" 83 } 84 node { 85 input: "Input" 86 input: "Weight" 87 output: "Output" 88 name: "Convolution" 89 op_type: "Conv" 90 attribute { 91 name: "kernel_shape" 92 ints: 3 93 ints: 3 94 type: INTS 95 } 96 attribute { 97 name: "strides" 98 ints: 1 99 ints: 1 100 type: INTS 101 } 102 attribute { 103 name: "auto_pad" 104 s: "VALID" 105 type: STRING 106 } 107 attribute { 108 name: "group" 109 i: 1 110 type: INT 111 } 112 attribute { 113 name: "dilations" 114 ints: 1 115 ints: 1 116 type: INTS 117 } 118 doc_string: "" 119 domain: "" 120 } 121 output { 122 name: "Output" 123 type { 124 tensor_type { 125 elem_type: 1 126 shape { 127 dim { 128 dim_value: 1 129 } 130 dim { 131 dim_value: 1 132 } 133 dim { 134 dim_value: 1 135 } 136 dim { 137 dim_value: 1 138 } 139 } 140 } 141 } 142 } 143 } 144 opset_import { 145 version: 7 146 })"; 147 Setup(); 148 } 149 }; 150 151 struct Conv2DWithBiasesFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 152 { Conv2DWithBiasesFixtureConv2DWithBiasesFixture153 Conv2DWithBiasesFixture() { 154 m_Prototext = R"( 155 ir_version: 3 156 producer_name: "CNTK" 157 producer_version: "2.5.1" 158 domain: "ai.cntk" 159 model_version: 1 160 graph { 161 name: "CNTKGraph" 162 input { 163 name: "Input" 164 type { 165 tensor_type { 166 elem_type: 1 167 shape { 168 dim { 169 dim_value: 1 170 } 171 dim { 172 dim_value: 1 173 } 174 dim { 175 dim_value: 2 176 } 177 dim { 178 dim_value: 2 179 } 180 } 181 } 182 } 183 } 184 input { 185 name: "Weight" 186 type { 187 tensor_type { 188 elem_type: 1 189 shape { 190 dim { 191 dim_value: 1 192 } 193 dim { 194 dim_value: 1 195 } 196 dim { 197 dim_value: 2 198 } 199 dim { 200 dim_value: 2 201 } 202 } 203 } 204 } 205 } 206 initializer { 207 dims: 1 208 dims: 1 209 dims: 2 210 dims: 2 211 data_type: 1 212 float_data: 2 213 float_data: 1 214 float_data: 0 215 float_data: 6 216 name: "Weight" 217 } 218 input { 219 name: "Bias" 220 type { 221 tensor_type { 222 elem_type: 1 223 shape { 224 dim { 225 dim_value: 4 226 } 227 } 228 } 229 } 230 } 231 initializer { 232 dims: 4 233 data_type: 1 234 float_data: 10 235 float_data: 0 236 float_data: 0 237 float_data: 0 238 name: "Bias" 239 } 240 node { 241 input: "Input" 242 input: "Weight" 243 input: "Bias" 244 output: "Output" 245 name: "Convolution" 246 op_type: "Conv" 247 attribute { 248 name: "kernel_shape" 249 ints: 2 250 ints: 2 251 type: INTS 252 } 253 attribute { 254 name: "strides" 255 ints: 1 256 ints: 1 257 type: INTS 258 } 259 attribute { 260 name: "auto_pad" 261 s: "SAME_UPPER" 262 type: STRING 263 } 264 attribute { 265 name: "group" 266 i: 1 267 type: INT 268 } 269 attribute { 270 name: "dilations" 271 ints: 1 272 ints: 1 273 type: INTS 274 } 275 doc_string: "" 276 domain: "" 277 } 278 output { 279 name: "Output" 280 type { 281 tensor_type { 282 elem_type: 1 283 shape { 284 dim { 285 dim_value: 1 286 } 287 dim { 288 dim_value: 1 289 } 290 dim { 291 dim_value: 2 292 } 293 dim { 294 dim_value: 2 295 } 296 } 297 } 298 } 299 } 300 } 301 opset_import { 302 version: 7 303 })"; 304 Setup(); 305 } 306 }; 307 308 309 struct Conv2DDimReducingFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 310 { Conv2DDimReducingFixtureConv2DDimReducingFixture311 Conv2DDimReducingFixture() { 312 m_Prototext = R"( 313 ir_version: 3 314 producer_name: "CNTK" 315 producer_version: "2.5.1" 316 domain: "ai.cntk" 317 model_version: 1 318 graph { 319 name: "CNTKGraph" 320 input { 321 name: "Input" 322 type { 323 tensor_type { 324 elem_type: 1 325 shape { 326 dim { 327 dim_value: 1 328 } 329 dim { 330 dim_value: 3 331 } 332 dim { 333 dim_value: 2 334 } 335 dim { 336 dim_value: 2 337 } 338 } 339 } 340 } 341 } 342 input { 343 name: "Weight" 344 type { 345 tensor_type { 346 elem_type: 1 347 shape { 348 dim { 349 dim_value: 2 350 } 351 dim { 352 dim_value: 3 353 } 354 dim { 355 dim_value: 1 356 } 357 dim { 358 dim_value: 1 359 } 360 } 361 } 362 } 363 } 364 initializer { 365 dims: 2 366 dims: 3 367 dims: 1 368 dims: 1 369 data_type: 1 370 float_data: -1 371 float_data: 2 372 float_data: 0 373 float_data: 1 374 float_data: 0 375 float_data: 0 376 name: "Weight" 377 } 378 node { 379 input: "Input" 380 input: "Weight" 381 output: "Output" 382 name: "Convolution" 383 op_type: "Conv" 384 attribute { 385 name: "kernel_shape" 386 ints: 1 387 ints: 1 388 type: INTS 389 } 390 attribute { 391 name: "strides" 392 ints: 1 393 ints: 1 394 type: INTS 395 } 396 attribute { 397 name: "group" 398 i: 1 399 type: INT 400 } 401 attribute { 402 name: "dilations" 403 ints: 1 404 ints: 1 405 type: INTS 406 } 407 doc_string: "" 408 domain: "" 409 } 410 output { 411 name: "Output" 412 type { 413 tensor_type { 414 elem_type: 1 415 shape { 416 dim { 417 dim_value: 1 418 } 419 dim { 420 dim_value: 2 421 } 422 dim { 423 dim_value: 2 424 } 425 dim { 426 dim_value: 2 427 } 428 } 429 } 430 } 431 } 432 } 433 opset_import { 434 version: 7 435 })"; 436 Setup(); 437 } 438 }; 439 440 struct Conv2DwithDilationFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 441 { Conv2DwithDilationFixtureConv2DwithDilationFixture442 Conv2DwithDilationFixture() 443 { 444 m_Prototext = R"( 445 ir_version: 3 446 producer_name: "CNTK" 447 producer_version: "2.5.1" 448 domain: "ai.cntk" 449 model_version: 1 450 graph { 451 name: "CNTKGraph" 452 input { 453 name: "Input" 454 type { 455 tensor_type { 456 elem_type: 1 457 shape { 458 dim { 459 dim_value: 1 460 } 461 dim { 462 dim_value: 1 463 } 464 dim { 465 dim_value: 6 466 } 467 dim { 468 dim_value: 6 469 } 470 } 471 } 472 } 473 } 474 input { 475 name: "Weight" 476 type { 477 tensor_type { 478 elem_type: 1 479 shape { 480 dim { 481 dim_value: 1 482 } 483 dim { 484 dim_value: 1 485 } 486 dim { 487 dim_value: 3 488 } 489 dim { 490 dim_value: 3 491 } 492 } 493 } 494 } 495 } 496 initializer { 497 dims: 1 498 dims: 1 499 dims: 3 500 dims: 3 501 data_type: 1 502 float_data: 2 503 float_data: 1 504 float_data: 0 505 float_data: 6 506 float_data: 2 507 float_data: 1 508 float_data: 4 509 float_data: 1 510 float_data: 2 511 name: "Weight" 512 } 513 node { 514 input: "Input" 515 input: "Weight" 516 output: "Output" 517 name: "Convolution" 518 op_type: "Conv" 519 attribute { 520 name: "kernel_shape" 521 ints: 3 522 ints: 3 523 type: INTS 524 } 525 attribute { 526 name: "strides" 527 ints: 1 528 ints: 1 529 type: INTS 530 } 531 attribute { 532 name: "auto_pad" 533 s: "VALID" 534 type: STRING 535 } 536 attribute { 537 name: "group" 538 i: 1 539 type: INT 540 } 541 attribute { 542 name: "dilations" 543 ints: 2 544 ints: 2 545 type: INTS 546 } 547 doc_string: "" 548 domain: "" 549 } 550 output { 551 name: "Output" 552 type { 553 tensor_type { 554 elem_type: 1 555 shape { 556 dim { 557 dim_value: 1 558 } 559 dim { 560 dim_value: 1 561 } 562 dim { 563 dim_value: 2 564 } 565 dim { 566 dim_value: 2 567 } 568 } 569 } 570 } 571 } 572 } 573 opset_import { 574 version: 7 575 })"; 576 Setup(); 577 } 578 }; 579 580 TEST_CASE_FIXTURE(SimpleConv2DFixture, "ValidConvTest") 581 { 582 RunTest<4>({{"Input", {1.0, 2.0, 3.0, 583 4.0, 5.0, 6.0, 584 7.0, 8.0, 9.0}}}, 585 {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 + 586 4.0 * 6 + 5.0 * 2 + 6.0 * 1 + 587 7.0 * 4 + 8.0 * 1 + 9.0 * 2}}}); 588 } 589 590 TEST_CASE_FIXTURE(Conv2DWithBiasesFixture, "ValidConvWithBiasTest") 591 { 592 RunTest<4>({{"Input", {1.0, 2.0, 593 3.0, 4.0}}}, 594 {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 + 4 * 6 + 10, 595 2.0 * 2 + 0 * 1 + 4.0 * 0 + 0 * 6 + 10, 596 3.0 * 2 + 4.0 * 1 + 0 * 0 + 0 * 6 + 10, 597 4.0 * 2 + 0 * 1 + 0 * 0 + 0 * 6 + 10}}}); 598 } 599 600 TEST_CASE_FIXTURE(Conv2DDimReducingFixture, "ValidConvDimReducTest") 601 { 602 RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, -1, -2, 3, 4, 1 , 1, 1, 1 }}}, 603 {{"Output", {-1 * 1 + 2 * -1, -1 * 2 + 2 * -2, 604 -1 * 3 + 2 * 3, -1 * 4 + 2 * 4, 605 1, 2, 3, 4}}}); 606 } 607 608 TEST_CASE_FIXTURE(Conv2DwithDilationFixture, "ValidConvWithDilationTest") 609 { 610 RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 611 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 612 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 613 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 614 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 615 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}}}, 616 {{"Output", {39.0, 58.0, 153.0, 172.0 }}}); 617 } 618 619 } 620