// xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Pad.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 
9 TEST_SUITE("TensorflowLiteParser_Pad")
10 {
11 struct PadFixture : public ParserFlatbuffersFixture
12 {
PadFixturePadFixture13     explicit PadFixture(const std::string& inputShape,
14                         const std::string& outputShape,
15                         const std::string& padListShape,
16                         const std::string& padListData,
17                         const std::string& dataType = "FLOAT32",
18                         const std::string& scale = "1.0",
19                         const std::string& offset = "0")
20     {
21         m_JsonString = R"(
22             {
23                 "version": 3,
24                 "operator_codes": [ { "builtin_code": "PAD" } ],
25                 "subgraphs": [ {
26                     "tensors": [
27                         {
28                             "shape": )" + inputShape + R"(,
29                             "type": )" + dataType + R"(,
30                             "buffer": 0,
31                             "name": "inputTensor",
32                             "quantization": {
33                                 "min": [ 0.0 ],
34                                 "max": [ 255.0 ],
35                                 "scale": [ )" + scale + R"( ],
36                                 "zero_point": [ )" + offset + R"( ],
37                             }
38                         },
39                         {
40                              "shape": )" + outputShape + R"(,
41                              "type": )" + dataType + R"(,
42                              "buffer": 1,
43                              "name": "outputTensor",
44                              "quantization": {
45                                 "min": [ 0.0 ],
46                                 "max": [ 255.0 ],
47                                 "scale": [ )" + scale + R"( ],
48                                 "zero_point": [ )" + offset + R"( ],
49                             }
50                         },
51                         {
52                              "shape": )" + padListShape + R"( ,
53                              "type": "INT32",
54                              "buffer": 2,
55                              "name": "padList",
56                              "quantization": {
57                                 "min": [ 0.0 ],
58                                 "max": [ 255.0 ],
59                                 "scale": [ 1.0 ],
60                                 "zero_point": [ 0 ],
61                              }
62                         }
63                     ],
64                     "inputs": [ 0 ],
65                     "outputs": [ 1 ],
66                     "operators": [
67                         {
68                             "opcode_index": 0,
69                             "inputs": [ 0, 2 ],
70                             "outputs": [ 1 ],
71                             "custom_options_format": "FLEXBUFFERS"
72                         }
73                     ],
74                 } ],
75                 "buffers" : [
76                     { },
77                     { },
78                     { "data": )" + padListData + R"(, },
79                 ]
80             }
81         )";
82       SetupSingleInputSingleOutput("inputTensor", "outputTensor");
83     }
84 };
85 
86 struct SimplePadFixture : public PadFixture
87 {
SimplePadFixtureSimplePadFixture88     SimplePadFixture() : PadFixture("[ 2, 3 ]", "[ 4, 7 ]", "[ 2, 2 ]",
89                                     "[  1,0,0,0, 1,0,0,0, 2,0,0,0, 2,0,0,0 ]") {}
90 };
91 
92 TEST_CASE_FIXTURE(SimplePadFixture, "ParsePad")
93 {
94     RunTest<2, armnn::DataType::Float32>
95         (0,
96          {{ "inputTensor",  { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
97          {{ "outputTensor", { 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
98                               0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f,
99                               0.0f, 0.0f, 4.0f, 5.0f, 6.0f, 0.0f, 0.0f,
100                               0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f }}});
101 }
102 
103 struct Uint8PadFixture : public PadFixture
104 {
Uint8PadFixtureUint8PadFixture105     Uint8PadFixture() : PadFixture("[ 2, 3 ]", "[ 4, 7 ]", "[ 2, 2 ]",
106                                   "[  1,0,0,0, 1,0,0,0, 2,0,0,0, 2,0,0,0 ]",
107                                   "UINT8", "-2.0", "3") {}
108 };
109 
110 TEST_CASE_FIXTURE(Uint8PadFixture, "ParsePadUint8")
111 {
112     RunTest<2, armnn::DataType::QAsymmU8>
113         (0,
114          {{ "inputTensor",  { 1, 2, 3, 4, 5, 6 }}},
115          {{ "outputTensor", { 3, 3, 3, 3, 3, 3, 3,
116                               3, 3, 1, 2, 3, 3, 3,
117                               3, 3, 4, 5, 6, 3, 3,
118                               3, 3, 3, 3, 3, 3, 3 }}});
119 }
120 
121 struct Int8PadFixture : public PadFixture
122 {
Int8PadFixtureInt8PadFixture123     Int8PadFixture() : PadFixture("[ 2, 3 ]", "[ 4, 7 ]", "[ 2, 2 ]",
124                                     "[  1,0,0,0, 1,0,0,0, 2,0,0,0, 2,0,0,0 ]",
125                                     "INT8", "-2.0", "3") {}
126 };
127 
128 TEST_CASE_FIXTURE(Int8PadFixture, "ParsePadInt8")
129 {
130     RunTest<2, armnn::DataType::QAsymmS8>
131         (0,
132          {{ "inputTensor",  { 1, -2, 3, 4, 5, -6 }}},
133          {{ "outputTensor", { 3, 3, 3, 3, 3, 3, 3,
134                               3, 3, 1, -2, 3, 3, 3,
135                               3, 3, 4, 5, -6, 3, 3,
136                               3, 3, 3, 3, 3, 3, 3 }}});
137 }
138 
139 }
140