1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 8 TEST_SUITE("TensorflowLiteParser_TransposeConv") 9 { 10 struct TransposeConvFixture : public ParserFlatbuffersFixture 11 { TransposeConvFixtureTransposeConvFixture12 explicit TransposeConvFixture(const std::string& inputShape, 13 const std::string& outputShape, 14 const std::string& filterShape, 15 const std::string& filterData, 16 const std::string& strideX, 17 const std::string& strideY, 18 const std::string& dataType) 19 { 20 m_JsonString = R"( 21 { 22 "version": 3, 23 "operator_codes": [ { "builtin_code": "TRANSPOSE_CONV" } ], 24 "subgraphs": [ { 25 "tensors": [ 26 { 27 "shape": [ 4 ], 28 "type": "UINT8", 29 "buffer": 0, 30 "name": "outputShapeTensor", 31 "quantization": { 32 "min": [ 0.0 ], 33 "max": [ 255.0 ], 34 "scale": [ 1.0 ], 35 "zero_point": [ 0 ], 36 } 37 }, 38 { 39 "shape": )" + filterShape + R"(, 40 "type": ")" + dataType + R"(", 41 "buffer": 1, 42 "name": "filterTensor", 43 "quantization": { 44 "min": [ 0.0 ], 45 "max": [ 255.0 ], 46 "scale": [ 1.0 ], 47 "zero_point": [ 0 ], 48 } 49 }, 50 { 51 "shape": )" + inputShape + R"(, 52 "type": ")" + dataType + R"(", 53 "buffer": 2, 54 "name": "inputTensor", 55 "quantization": { 56 "min": [ 0.0 ], 57 "max": [ 255.0 ], 58 "scale": [ 1.0 ], 59 "zero_point": [ 0 ], 60 } 61 }, 62 { 63 "shape": )" + outputShape + R"(, 64 "type": ")" + dataType + R"(", 65 "buffer": 3, 66 "name": "outputTensor", 67 "quantization": { 68 "min": [ 0.0 ], 69 "max": [ 255.0 ], 70 "scale": [ 1.0 ], 71 "zero_point": [ 0 ], 72 } 73 } 74 ], 75 "inputs": [ 2 ], 76 "outputs": [ 3 ], 77 "operators": [ 78 { 79 "opcode_index": 0, 80 "inputs": [ 0, 1, 2 ], 81 "outputs": [ 3 ], 82 "builtin_options_type": "TransposeConvOptions", 83 "builtin_options": { 84 "padding": "VALID", 85 "stride_w": )" + strideX + R"(, 86 "stride_h": )" + strideY + R"( 87 }, 88 "custom_options_format": "FLEXBUFFERS" 89 } 90 ], 91 } ], 92 "buffers" : [ 93 { "data": )" + outputShape + R"( }, 94 { "data": )" + filterData + R"( }, 95 { }, 96 { } 97 ] 98 } 99 )"; 100 SetupSingleInputSingleOutput("inputTensor", "outputTensor"); 101 } 102 }; 103 104 struct SimpleTransposeConvFixture : TransposeConvFixture 105 { SimpleTransposeConvFixtureSimpleTransposeConvFixture106 SimpleTransposeConvFixture() 107 : TransposeConvFixture("[ 1, 2, 2, 1 ]", // inputShape 108 "[ 1, 3, 3, 1 ]", // outputShape 109 "[ 1, 2, 2, 1 ]", // filterShape 110 "[ 0, 1, 2, 4 ]", // filterData 111 "1", // strideX 112 "1", // strideY 113 "UINT8") // dataType 114 {} 115 }; 116 117 TEST_CASE_FIXTURE(SimpleTransposeConvFixture, "ParseSimpleTransposeConv") 118 { 119 RunTest<4, armnn::DataType::QAsymmU8>( 120 0, 121 { 122 1, 2, 123 3, 4 124 }, 125 { 126 0, 1, 2, 127 2, 11, 12, 128 6, 20, 16 129 }); 130 } 131 132 struct TransposeConvFixtureWithBias : public ParserFlatbuffersFixture 133 { TransposeConvFixtureWithBiasTransposeConvFixtureWithBias134 explicit TransposeConvFixtureWithBias(const std::string& inputShape, 135 const std::string& outputShape, 136 const std::string& filterShape, 137 const std::string& filterData, 138 const std::string& strideX, 139 const std::string& strideY, 140 const std::string& dataType, 141 const std::string& biasShape, 142 const std::string& biasData) 143 { 144 m_JsonString = R"( 145 { 146 "version": 3, 147 "operator_codes": [ { "builtin_code": "TRANSPOSE_CONV" } ], 148 "subgraphs": [ { 149 "tensors": [ 150 { 151 "shape": [ 4 ], 152 "type": "UINT8", 153 "buffer": 0, 154 "name": "outputShapeTensor", 155 "quantization": { 156 "min": [ 0.0 ], 157 "max": [ 255.0 ], 158 "scale": [ 1.0 ], 159 "zero_point": [ 0 ], 160 } 161 }, 162 { 163 "shape": )" + filterShape + R"(, 164 "type": ")" + dataType + R"(", 165 "buffer": 1, 166 "name": "filterTensor", 167 "quantization": { 168 "min": [ 0.0 ], 169 "max": [ 255.0 ], 170 "scale": [ 1.0 ], 171 "zero_point": [ 0 ], 172 } 173 }, 174 { 175 "shape": )" + inputShape + R"(, 176 "type": ")" + dataType + R"(", 177 "buffer": 2, 178 "name": "inputTensor", 179 "quantization": { 180 "min": [ 0.0 ], 181 "max": [ 255.0 ], 182 "scale": [ 1.0 ], 183 "zero_point": [ 0 ], 184 } 185 }, 186 { 187 "shape": )" + biasShape + R"( , 188 "type": "INT32", 189 "buffer": 3, 190 "name": "biasTensor", 191 "quantization": { 192 "min": [ 0.0 ], 193 "max": [ 255.0 ], 194 "scale": [ 1.0 ], 195 "zero_point": [ 0 ], 196 } 197 }, 198 { 199 "shape": )" + outputShape + R"(, 200 "type": ")" + dataType + R"(", 201 "buffer": 4, 202 "name": "outputTensor", 203 "quantization": { 204 "min": [ 0.0 ], 205 "max": [ 255.0 ], 206 "scale": [ 1.0 ], 207 "zero_point": [ 0 ], 208 } 209 } 210 ], 211 "inputs": [ 2 ], 212 "outputs": [ 4 ], 213 "operators": [ 214 { 215 "opcode_index": 0, 216 "inputs": [ 0, 1, 2, 3], 217 "outputs": [ 4 ], 218 "builtin_options_type": "TransposeConvOptions", 219 "builtin_options": { 220 "padding": "VALID", 221 "stride_w": )" + strideX + R"(, 222 "stride_h": )" + strideY + R"( 223 }, 224 "custom_options_format": "FLEXBUFFERS" 225 } 226 ], 227 } ], 228 "buffers" : [ 229 { "data": )" + outputShape + R"( }, 230 { "data": )" + filterData + R"( }, 231 { }, 232 { "data": )" + biasData + R"( }, 233 { } 234 ] 235 } 236 )"; 237 SetupSingleInputSingleOutput("inputTensor", "outputTensor"); 238 } 239 }; 240 241 struct SimpleTransposeConvFixtureWithBias : TransposeConvFixtureWithBias 242 { SimpleTransposeConvFixtureWithBiasSimpleTransposeConvFixtureWithBias243 SimpleTransposeConvFixtureWithBias() 244 : TransposeConvFixtureWithBias("[ 1, 2, 2, 1 ]", // inputShape 245 "[ 1, 3, 3, 1 ]", // outputShape 246 "[ 1, 2, 2, 1 ]", // filterShape 247 "[ 0, 1, 2, 4 ]", // filterData 248 "1", // strideX 249 "1", // strideY 250 "UINT8", // dataType 251 "[ 1 ]", // bias shape 252 "[ 10, 0, 0, 0 ]") // bias data 253 {} 254 }; 255 256 TEST_CASE_FIXTURE(SimpleTransposeConvFixtureWithBias, "ParseSimpleTransposeConvWithBias") 257 { 258 RunTest<4, armnn::DataType::QAsymmU8>( 259 0, 260 { 261 1, 2, 262 3, 4 263 }, 264 { 265 10, 11, 12, 266 12, 21, 22, 267 16, 30, 26 268 }); 269 } 270 271 272 struct TransposeConvPerChannelFixture : public ParserFlatbuffersFixture 273 { TransposeConvPerChannelFixtureTransposeConvPerChannelFixture274 explicit TransposeConvPerChannelFixture() 275 { 276 m_JsonString = R"( 277 { 278 "version": 3, 279 "operator_codes": [ 280 { 281 "builtin_code": "TRANSPOSE_CONV", 282 "version": 2 283 } 284 ], 285 "subgraphs": [ 286 { 287 "tensors": [ 288 { 289 "shape": [ 290 1, 291 4, 292 4, 293 2 294 ], 295 "type": "INT8", 296 "buffer": 1, 297 "name": "input", 298 "quantization": { 299 "min": [ 300 -50.0 301 ], 302 "max": [ 303 49.0 304 ], 305 "scale": [ 306 0.388235 307 ], 308 "zero_point": [ 309 1 310 ], 311 "details_type": "NONE", 312 "quantized_dimension": 0 313 }, 314 "is_variable": false 315 }, 316 { 317 "shape": [ 318 4 319 ], 320 "type": "INT32", 321 "buffer": 2, 322 "name": "model/conv2d_transpose/stack", 323 "quantization": { 324 "details_type": "NONE", 325 "quantized_dimension": 0 326 }, 327 "is_variable": false 328 }, 329 { 330 "shape": [ 331 8, 332 2, 333 2, 334 2 335 ], 336 "type": "INT8", 337 "buffer": 3, 338 "name": "model/conv2d_transpose/conv2d_transpose", 339 "quantization": { 340 "min": [ 341 -0.081948, 342 -0.379918, 343 -0.223632, 344 -0.098629, 345 -0.386369, 346 -0.351057, 347 -0.348749, 348 -0.264848 349 ], 350 "max": [ 351 0.35091, 352 0.229681, 353 0.368384, 354 0.176761, 355 0.353717, 356 0.377565, 357 0.373713, 358 0.30141 359 ], 360 "scale": [ 361 0.002763, 362 0.002991, 363 0.002901, 364 0.001392, 365 0.003042, 366 0.002973, 367 0.002943, 368 0.002373 369 ], 370 "zero_point": [ 371 0, 372 0, 373 0, 374 0, 375 0, 376 0, 377 0, 378 0 379 ], 380 "details_type": "NONE", 381 "quantized_dimension": 0 382 }, 383 "is_variable": false 384 }, 385 { 386 "shape": [ 387 1, 388 4, 389 4, 390 8 391 ], 392 "type": "INT8", 393 "buffer": 4, 394 "name": "Identity", 395 "quantization": { 396 "min": [ 397 -63.578175 398 ], 399 "max": [ 400 69.305023 401 ], 402 "scale": [ 403 0.521111 404 ], 405 "zero_point": [ 406 -6 407 ], 408 "details_type": "NONE", 409 "quantized_dimension": 0 410 }, 411 "is_variable": false 412 } 413 ], 414 "inputs": [ 415 0 416 ], 417 "outputs": [ 418 3 419 ], 420 "operators": [ 421 { 422 "opcode_index": 0, 423 "inputs": [ 424 1, 425 2, 426 0 427 ], 428 "outputs": [ 429 3 430 ], 431 "builtin_options_type": "TransposeConvOptions", 432 "builtin_options": { 433 "padding": "SAME", 434 "stride_w": 1, 435 "stride_h": 1 436 }, 437 "custom_options_format": "FLEXBUFFERS" 438 } 439 ], 440 "name": "main" 441 } 442 ], 443 "description": "MLIR Converted.", 444 "buffers": [ 445 { 446 }, 447 { 448 }, 449 { 450 "data": [ 451 1, 452 0, 453 0, 454 0, 455 4, 456 0, 457 0, 458 0, 459 4, 460 0, 461 0, 462 0, 463 8, 464 0, 465 0, 466 0 467 ] 468 }, 469 { 470 "data": [ 471 13, 472 239, 473 7, 474 125, 475 35, 476 127, 477 55, 478 226, 479 77, 480 150, 481 159, 482 192, 483 180, 484 129, 485 51, 486 48, 487 108, 488 9, 489 21, 490 179, 491 12, 492 39, 493 127, 494 107, 495 44, 496 206, 497 127, 498 185, 499 108, 500 82, 501 86, 502 218, 503 38, 504 149, 505 16, 506 1, 507 129, 508 163, 509 116, 510 136, 511 138, 512 43, 513 65, 514 186, 515 154, 516 138, 517 64, 518 127, 519 120, 520 127, 521 207, 522 70, 523 43, 524 33, 525 141, 526 137, 527 93, 528 215, 529 65, 530 92, 531 122, 532 144, 533 120, 534 127 535 ] 536 }, 537 { 538 }, 539 { 540 "data": [ 541 49, 542 46, 543 57, 544 46, 545 48, 546 0, 547 0, 548 0, 549 0, 550 0, 551 0, 552 0, 553 0, 554 0, 555 0, 556 0 557 ] 558 } 559 ], 560 "metadata": [ 561 { 562 "name": "min_runtime_version", 563 "buffer": 5 564 } 565 ] 566 } 567 )"; 568 SetupSingleInputSingleOutput("input", "Identity"); 569 } 570 }; 571 572 TEST_CASE_FIXTURE(TransposeConvPerChannelFixture, "ParseTransposeConvPerChannel") 573 { 574 RunTest<4, armnn::DataType::QAsymmS8>( 575 0, 576 { 577 -11, 40,-26, 11,-28, 8, 0, -8, 578 -10, 34, 47, 0,-33,-14, 28, 35, 579 6,-28,-26, 8, 13, 33,-31,-41, 580 31,-20,-31,-16, 8,-18,-44, 0 581 }, 582 { 583 -8,-17, -8, -9,-16, 1, 2,-11, 584 3,-16,-19,-12,-11, -6, -3, -6, 585 -5, -8,-16,-12,-11, -3, -7,-13, 586 -4, 1, -9,-10, -5,-12, -5, -8, 587 2,-25, -5, -6,-20, -7, 2,-21, 588 1, 4, 5,-13,-10,-12, 3, 4, 589 -10,-17,-17, -6, -7, 12,-22,-17, 590 -17, 0, -5,-14,-21,-12, 17,-13, 591 3, -6, -3, -3, -2,-16,-11,-12, 592 -15,-14, -1, -2,-35, 5,-18, 0, 593 -6, 8, 5,-12, 12, 7, -6, -3, 594 11,-28,-28, -3,-18,-29, -5,-13, 595 -12, 11, -2, -5, 6, -9, -6, 7, 596 -9,-11,-14, -2, 12, 5,-21,-23, 597 -4, -4, -6, -6,-21,-25, 0,-18, 598 -26, 10, -7,-13, 3, 39,-39, -4 599 }); 600 } 601 602 } 603