// xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/TransposeConv.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 TEST_SUITE("TensorflowLiteParser_TransposeConv")
9 {
10 struct TransposeConvFixture : public ParserFlatbuffersFixture
11 {
TransposeConvFixtureTransposeConvFixture12     explicit TransposeConvFixture(const std::string& inputShape,
13                                   const std::string& outputShape,
14                                   const std::string& filterShape,
15                                   const std::string& filterData,
16                                   const std::string& strideX,
17                                   const std::string& strideY,
18                                   const std::string& dataType)
19     {
20         m_JsonString = R"(
21             {
22                 "version": 3,
23                 "operator_codes": [ { "builtin_code": "TRANSPOSE_CONV" } ],
24                 "subgraphs": [ {
25                     "tensors": [
26                         {
27                             "shape": [ 4 ],
28                             "type": "UINT8",
29                             "buffer": 0,
30                             "name": "outputShapeTensor",
31                             "quantization": {
32                                 "min": [ 0.0 ],
33                                 "max": [ 255.0 ],
34                                 "scale": [ 1.0 ],
35                                 "zero_point": [ 0 ],
36                             }
37                         },
38                         {
39                             "shape": )" + filterShape + R"(,
40                             "type": ")" + dataType + R"(",
41                             "buffer": 1,
42                             "name": "filterTensor",
43                             "quantization": {
44                                 "min": [ 0.0 ],
45                                 "max": [ 255.0 ],
46                                 "scale": [ 1.0 ],
47                                 "zero_point": [ 0 ],
48                             }
49                         },
50                         {
51                             "shape": )" + inputShape + R"(,
52                             "type": ")" + dataType + R"(",
53                             "buffer": 2,
54                             "name": "inputTensor",
55                             "quantization": {
56                                 "min": [ 0.0 ],
57                                 "max": [ 255.0 ],
58                                 "scale": [ 1.0 ],
59                                 "zero_point": [ 0 ],
60                             }
61                         },
62                         {
63                             "shape": )" + outputShape + R"(,
64                             "type": ")" + dataType + R"(",
65                             "buffer": 3,
66                             "name": "outputTensor",
67                             "quantization": {
68                                 "min": [ 0.0 ],
69                                 "max": [ 255.0 ],
70                                 "scale": [ 1.0 ],
71                                 "zero_point": [ 0 ],
72                             }
73                         }
74                     ],
75                     "inputs": [ 2 ],
76                     "outputs": [ 3 ],
77                     "operators": [
78                         {
79                             "opcode_index": 0,
80                             "inputs": [ 0, 1, 2 ],
81                             "outputs": [ 3 ],
82                             "builtin_options_type": "TransposeConvOptions",
83                             "builtin_options": {
84                                 "padding": "VALID",
85                                 "stride_w": )" + strideX + R"(,
86                                 "stride_h": )" + strideY + R"(
87                             },
88                             "custom_options_format": "FLEXBUFFERS"
89                         }
90                     ],
91                 } ],
92                 "buffers" : [
93                     { "data": )" + outputShape + R"( },
94                     { "data": )" + filterData + R"( },
95                     { },
96                     { }
97                 ]
98             }
99         )";
100         SetupSingleInputSingleOutput("inputTensor", "outputTensor");
101     }
102 };
103 
104 struct SimpleTransposeConvFixture : TransposeConvFixture
105 {
SimpleTransposeConvFixtureSimpleTransposeConvFixture106     SimpleTransposeConvFixture()
107     : TransposeConvFixture("[ 1, 2, 2, 1 ]",  // inputShape
108                            "[ 1, 3, 3, 1 ]",  // outputShape
109                            "[ 1, 2, 2, 1 ]",  // filterShape
110                            "[ 0, 1, 2, 4 ]",  // filterData
111                            "1",               // strideX
112                            "1",               // strideY
113                            "UINT8")           // dataType
114     {}
115 };
116 
117 TEST_CASE_FIXTURE(SimpleTransposeConvFixture, "ParseSimpleTransposeConv")
118 {
119     RunTest<4, armnn::DataType::QAsymmU8>(
120         0,
121         {
122             1, 2,
123             3, 4
124         },
125         {
126             0, 1,  2,
127             2, 11, 12,
128             6, 20, 16
129         });
130 }
131 
132 struct TransposeConvFixtureWithBias : public ParserFlatbuffersFixture
133 {
TransposeConvFixtureWithBiasTransposeConvFixtureWithBias134     explicit TransposeConvFixtureWithBias(const std::string& inputShape,
135                                           const std::string& outputShape,
136                                           const std::string& filterShape,
137                                           const std::string& filterData,
138                                           const std::string& strideX,
139                                           const std::string& strideY,
140                                           const std::string& dataType,
141                                           const std::string& biasShape,
142                                           const std::string& biasData)
143     {
144         m_JsonString = R"(
145             {
146                 "version": 3,
147                 "operator_codes": [ { "builtin_code": "TRANSPOSE_CONV" } ],
148                 "subgraphs": [ {
149                     "tensors": [
150                         {
151                             "shape": [ 4 ],
152                             "type": "UINT8",
153                             "buffer": 0,
154                             "name": "outputShapeTensor",
155                             "quantization": {
156                                 "min": [ 0.0 ],
157                                 "max": [ 255.0 ],
158                                 "scale": [ 1.0 ],
159                                 "zero_point": [ 0 ],
160                             }
161                         },
162                         {
163                             "shape": )" + filterShape + R"(,
164                             "type": ")" + dataType + R"(",
165                             "buffer": 1,
166                             "name": "filterTensor",
167                             "quantization": {
168                                 "min": [ 0.0 ],
169                                 "max": [ 255.0 ],
170                                 "scale": [ 1.0 ],
171                                 "zero_point": [ 0 ],
172                             }
173                         },
174                         {
175                             "shape": )" + inputShape + R"(,
176                             "type": ")" + dataType + R"(",
177                             "buffer": 2,
178                             "name": "inputTensor",
179                             "quantization": {
180                                 "min": [ 0.0 ],
181                                 "max": [ 255.0 ],
182                                 "scale": [ 1.0 ],
183                                 "zero_point": [ 0 ],
184                             }
185                         },
186                         {
187                             "shape": )" + biasShape + R"( ,
188                             "type": "INT32",
189                             "buffer": 3,
190                             "name": "biasTensor",
191                             "quantization": {
192                                 "min": [ 0.0 ],
193                                 "max": [ 255.0 ],
194                                 "scale": [ 1.0 ],
195                                 "zero_point": [ 0 ],
196                             }
197                         },
198                         {
199                             "shape": )" + outputShape + R"(,
200                             "type": ")" + dataType + R"(",
201                             "buffer": 4,
202                             "name": "outputTensor",
203                             "quantization": {
204                                 "min": [ 0.0 ],
205                                 "max": [ 255.0 ],
206                                 "scale": [ 1.0 ],
207                                 "zero_point": [ 0 ],
208                             }
209                         }
210                     ],
211                     "inputs": [ 2 ],
212                     "outputs": [ 4 ],
213                     "operators": [
214                         {
215                             "opcode_index": 0,
216                             "inputs": [ 0, 1, 2, 3],
217                             "outputs": [ 4 ],
218                             "builtin_options_type": "TransposeConvOptions",
219                             "builtin_options": {
220                                 "padding": "VALID",
221                                 "stride_w": )" + strideX + R"(,
222                                 "stride_h": )" + strideY + R"(
223                             },
224                             "custom_options_format": "FLEXBUFFERS"
225                         }
226                     ],
227                 } ],
228                 "buffers" : [
229                     { "data": )" + outputShape + R"( },
230                     { "data": )" + filterData + R"( },
231                     { },
232                     { "data": )" + biasData + R"( },
233                     { }
234                 ]
235             }
236         )";
237         SetupSingleInputSingleOutput("inputTensor", "outputTensor");
238     }
239 };
240 
241 struct SimpleTransposeConvFixtureWithBias : TransposeConvFixtureWithBias
242 {
SimpleTransposeConvFixtureWithBiasSimpleTransposeConvFixtureWithBias243     SimpleTransposeConvFixtureWithBias()
244     : TransposeConvFixtureWithBias("[ 1, 2, 2, 1 ]",  // inputShape
245                                    "[ 1, 3, 3, 1 ]",  // outputShape
246                                    "[ 1, 2, 2, 1 ]",  // filterShape
247                                    "[ 0, 1, 2, 4 ]",  // filterData
248                                    "1",               // strideX
249                                    "1",               // strideY
250                                    "UINT8",           // dataType
251                                    "[ 1 ]",           // bias shape
252                                    "[ 10, 0, 0, 0 ]") // bias data
253     {}
254 };
255 
256 TEST_CASE_FIXTURE(SimpleTransposeConvFixtureWithBias, "ParseSimpleTransposeConvWithBias")
257 {
258     RunTest<4, armnn::DataType::QAsymmU8>(
259         0,
260         {
261             1, 2,
262             3, 4
263         },
264         {
265             10, 11, 12,
266             12, 21, 22,
267             16, 30, 26
268         });
269 }
270 
271 
272 struct TransposeConvPerChannelFixture : public ParserFlatbuffersFixture
273 {
TransposeConvPerChannelFixtureTransposeConvPerChannelFixture274     explicit TransposeConvPerChannelFixture()
275     {
276         m_JsonString = R"(
277         {
278             "version": 3,
279             "operator_codes": [
280                 {
281                     "builtin_code": "TRANSPOSE_CONV",
282                     "version": 2
283                 }
284             ],
285             "subgraphs": [
286                 {
287                     "tensors": [
288                         {
289                             "shape": [
290                                 1,
291                                 4,
292                                 4,
293                                 2
294                             ],
295                             "type": "INT8",
296                             "buffer": 1,
297                             "name": "input",
298                             "quantization": {
299                                 "min": [
300                                     -50.0
301                                 ],
302                                 "max": [
303                                     49.0
304                                 ],
305                                 "scale": [
306                                     0.388235
307                                 ],
308                                 "zero_point": [
309                                     1
310                                 ],
311                                 "details_type": "NONE",
312                                 "quantized_dimension": 0
313                             },
314                             "is_variable": false
315                         },
316                         {
317                             "shape": [
318                                 4
319                             ],
320                             "type": "INT32",
321                             "buffer": 2,
322                             "name": "model/conv2d_transpose/stack",
323                             "quantization": {
324                                 "details_type": "NONE",
325                                 "quantized_dimension": 0
326                             },
327                             "is_variable": false
328                         },
329                         {
330                             "shape": [
331                                 8,
332                                 2,
333                                 2,
334                                 2
335                             ],
336                             "type": "INT8",
337                             "buffer": 3,
338                             "name": "model/conv2d_transpose/conv2d_transpose",
339                             "quantization": {
340                                 "min": [
341                                     -0.081948,
342                                     -0.379918,
343                                     -0.223632,
344                                     -0.098629,
345                                     -0.386369,
346                                     -0.351057,
347                                     -0.348749,
348                                     -0.264848
349                                 ],
350                                 "max": [
351                                     0.35091,
352                                     0.229681,
353                                     0.368384,
354                                     0.176761,
355                                     0.353717,
356                                     0.377565,
357                                     0.373713,
358                                     0.30141
359                                 ],
360                                 "scale": [
361                                     0.002763,
362                                     0.002991,
363                                     0.002901,
364                                     0.001392,
365                                     0.003042,
366                                     0.002973,
367                                     0.002943,
368                                     0.002373
369                                 ],
370                                 "zero_point": [
371                                     0,
372                                     0,
373                                     0,
374                                     0,
375                                     0,
376                                     0,
377                                     0,
378                                     0
379                                 ],
380                                 "details_type": "NONE",
381                                 "quantized_dimension": 0
382                             },
383                             "is_variable": false
384                         },
385                         {
386                             "shape": [
387                                 1,
388                                 4,
389                                 4,
390                                 8
391                             ],
392                             "type": "INT8",
393                             "buffer": 4,
394                             "name": "Identity",
395                             "quantization": {
396                                 "min": [
397                                     -63.578175
398                                 ],
399                                 "max": [
400                                     69.305023
401                                 ],
402                                 "scale": [
403                                     0.521111
404                                 ],
405                                 "zero_point": [
406                                     -6
407                                 ],
408                                 "details_type": "NONE",
409                                 "quantized_dimension": 0
410                             },
411                             "is_variable": false
412                         }
413                     ],
414                     "inputs": [
415                         0
416                     ],
417                     "outputs": [
418                         3
419                     ],
420                     "operators": [
421                         {
422                             "opcode_index": 0,
423                             "inputs": [
424                                 1,
425                                 2,
426                                 0
427                             ],
428                             "outputs": [
429                                 3
430                             ],
431                             "builtin_options_type": "TransposeConvOptions",
432                             "builtin_options": {
433                                 "padding": "SAME",
434                                 "stride_w": 1,
435                                 "stride_h": 1
436                             },
437                             "custom_options_format": "FLEXBUFFERS"
438                         }
439                     ],
440                     "name": "main"
441                 }
442             ],
443             "description": "MLIR Converted.",
444             "buffers": [
445                 {
446                 },
447                 {
448                 },
449                 {
450                     "data": [
451                         1,
452                         0,
453                         0,
454                         0,
455                         4,
456                         0,
457                         0,
458                         0,
459                         4,
460                         0,
461                         0,
462                         0,
463                         8,
464                         0,
465                         0,
466                         0
467                     ]
468                 },
469                 {
470                     "data": [
471                         13,
472                         239,
473                         7,
474                         125,
475                         35,
476                         127,
477                         55,
478                         226,
479                         77,
480                         150,
481                         159,
482                         192,
483                         180,
484                         129,
485                         51,
486                         48,
487                         108,
488                         9,
489                         21,
490                         179,
491                         12,
492                         39,
493                         127,
494                         107,
495                         44,
496                         206,
497                         127,
498                         185,
499                         108,
500                         82,
501                         86,
502                         218,
503                         38,
504                         149,
505                         16,
506                         1,
507                         129,
508                         163,
509                         116,
510                         136,
511                         138,
512                         43,
513                         65,
514                         186,
515                         154,
516                         138,
517                         64,
518                         127,
519                         120,
520                         127,
521                         207,
522                         70,
523                         43,
524                         33,
525                         141,
526                         137,
527                         93,
528                         215,
529                         65,
530                         92,
531                         122,
532                         144,
533                         120,
534                         127
535                     ]
536                 },
537                 {
538                 },
539                 {
540                     "data": [
541                         49,
542                         46,
543                         57,
544                         46,
545                         48,
546                         0,
547                         0,
548                         0,
549                         0,
550                         0,
551                         0,
552                         0,
553                         0,
554                         0,
555                         0,
556                         0
557                     ]
558                 }
559               ],
560             "metadata": [
561                 {
562                     "name": "min_runtime_version",
563                     "buffer": 5
564                 }
565             ]
566         }
567         )";
568         SetupSingleInputSingleOutput("input", "Identity");
569     }
570 };
571 
572 TEST_CASE_FIXTURE(TransposeConvPerChannelFixture, "ParseTransposeConvPerChannel")
573 {
574     RunTest<4, armnn::DataType::QAsymmS8>(
575         0,
576         {
577             -11, 40,-26, 11,-28,  8,  0, -8,
578             -10, 34, 47,  0,-33,-14, 28, 35,
579               6,-28,-26,  8, 13, 33,-31,-41,
580              31,-20,-31,-16,  8,-18,-44,  0
581         },
582         {
583             -8,-17, -8, -9,-16,  1,  2,-11,
584              3,-16,-19,-12,-11, -6, -3, -6,
585             -5, -8,-16,-12,-11, -3, -7,-13,
586             -4,  1, -9,-10, -5,-12, -5, -8,
587              2,-25, -5, -6,-20, -7,  2,-21,
588              1,  4,  5,-13,-10,-12,  3,  4,
589            -10,-17,-17, -6, -7, 12,-22,-17,
590            -17,  0, -5,-14,-21,-12, 17,-13,
591              3, -6, -3, -3, -2,-16,-11,-12,
592            -15,-14, -1, -2,-35,  5,-18,  0,
593             -6,  8,  5,-12, 12,  7, -6, -3,
594             11,-28,-28, -3,-18,-29, -5,-13,
595            -12, 11, -2, -5,  6, -9, -6,  7,
596             -9,-11,-14, -2, 12,  5,-21,-23,
597             -4, -4, -6, -6,-21,-25,  0,-18,
598            -26, 10, -7,-13,  3, 39,-39, -4
599         });
600 }
601 
602 }
603