xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/MirrorPad.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 TEST_SUITE("TensorflowLiteParser_MirrorPad")
9 {
/// Test fixture that builds a single-operator TfLite model containing one
/// MIRROR_PAD operator, expressed as a FlatBuffers JSON string, and hands it
/// to the parser via SetupSingleInputSingleOutput.
///
/// Tensors in the generated subgraph:
///   0 "inputTensor"  -> buffer 0 (empty, i.e. no constant data)
///   1 "outputTensor" -> buffer 1 (empty)
///   2 "padList"      -> buffer 2 (constant INT32 data taken from padListData)
/// The operator reads tensors [0, 2] and writes tensor [1]; only tensor 0 is a
/// subgraph input and only tensor 1 is a subgraph output.
struct MirrorPadFixture : public ParserFlatbuffersFixture
{
    /// @param inputShape   JSON array text for the input tensor shape, e.g. "[ 3, 3 ]".
    /// @param outputShape  JSON array text for the output tensor shape.
    /// @param padListShape JSON array text for the shape of the padding tensor.
    /// @param padListData  JSON array text spliced verbatim as buffer 2's "data".
    ///                     NOTE(review): appears to be the raw bytes of the INT32
    ///                     pad values (e.g. "2,0,0,0" for 2, little-endian) — confirm
    ///                     against the TfLite schema.
    /// @param padMode      MirrorPadOptions "mode" value, spliced unquoted
    ///                     (e.g. "SYMMETRIC" or "REFLECT").
    /// @param dataType     Tensor "type" for both input and output, spliced unquoted.
    /// @param scale        Quantization scale text for both input and output.
    /// @param offset       Quantization zero point text for both input and output.
    explicit MirrorPadFixture(const std::string& inputShape,
                              const std::string& outputShape,
                              const std::string& padListShape,
                              const std::string& padListData,
                              const std::string& padMode,
                              const std::string& dataType = "FLOAT32",
                              const std::string& scale = "1.0",
                              const std::string& offset = "0")
    {
        // The model text below contains trailing commas and unquoted enum
        // identifiers, so it is not strict JSON.
        // NOTE(review): this relies on the FlatBuffers JSON parser accepting
        // non-strict input — confirm ParserFlatbuffersFixture parses it that way.
        m_JsonString = R"(
            {
                "version": 3,
                "operator_codes": [ { "builtin_code": "MIRROR_PAD" } ],
                "subgraphs": [ {
                    "tensors": [
                        {
                            "shape": )" + inputShape + R"(,
                            "type": )" + dataType + R"(,
                            "buffer": 0,
                            "name": "inputTensor",
                            "quantization": {
                                "min": [ 0.0 ],
                                "max": [ 255.0 ],
                                "scale": [ )" + scale + R"( ],
                                "zero_point": [ )" + offset + R"( ],
                            }
                        },
                        {
                             "shape": )" + outputShape + R"(,
                             "type": )" + dataType + R"(,
                             "buffer": 1,
                             "name": "outputTensor",
                             "quantization": {
                                "min": [ 0.0 ],
                                "max": [ 255.0 ],
                                "scale": [ )" + scale + R"( ],
                                "zero_point": [ )" + offset + R"( ],
                            }
                        },
                        {
                             "shape": )" + padListShape + R"( ,
                             "type": "INT32",
                             "buffer": 2,
                             "name": "padList",
                             "quantization": {
                                "min": [ 0.0 ],
                                "max": [ 255.0 ],
                                "scale": [ 1.0 ],
                                "zero_point": [ 0 ],
                             }
                        }
                    ],
                    "inputs": [ 0 ],
                    "outputs": [ 1 ],
                    "operators": [
                        {
                            "opcode_index": 0,
                            "inputs": [ 0, 2 ],
                            "outputs": [ 1 ],
                            "builtin_options_type": "MirrorPadOptions",
                            "builtin_options": {
                              "mode": )" + padMode + R"( ,
                            },
                            "custom_options_format": "FLEXBUFFERS"
                        }
                    ],
                } ],
                "buffers" : [
                    { },
                    { },
                    { "data": )" + padListData + R"(, },
                ]
            }
        )";
      // Presumably parses m_JsonString and registers the named tensors as the
      // network's single input and output (defined in ParserFlatbuffersFixture).
      SetupSingleInputSingleOutput("inputTensor", "outputTensor");
    }
};
89 
90 struct SimpleMirrorPadSymmetricFixture : public MirrorPadFixture
91 {
SimpleMirrorPadSymmetricFixtureSimpleMirrorPadSymmetricFixture92     SimpleMirrorPadSymmetricFixture() : MirrorPadFixture("[ 3, 3 ]", "[ 7, 7 ]", "[ 2, 2 ]",
93                                                          "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 2,0,0,0 ]",
94                                                          "SYMMETRIC", "FLOAT32") {}
95 };
96 
97 TEST_CASE_FIXTURE(SimpleMirrorPadSymmetricFixture, "ParseMirrorPadSymmetric")
98 {
99     RunTest<2, armnn::DataType::Float32>
100             (0,
101              {{ "inputTensor",  { 1.0f, 2.0f, 3.0f,
102                                   4.0f, 5.0f, 6.0f,
103                                   7.0f, 8.0f, 9.0f }}},
104 
105              {{ "outputTensor", { 5.0f, 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f,
106                                   2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f,
107                                   2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f,
108                                   5.0f, 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f,
109                                   8.0f, 7.0f, 7.0f, 8.0f, 9.0f, 9.0f, 8.0f,
110                                   8.0f, 7.0f, 7.0f, 8.0f, 9.0f, 9.0f, 8.0f,
111                                   5.0f, 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f }}});
112 }
113 
114 struct SimpleMirrorPadReflectFixture : public MirrorPadFixture
115 {
SimpleMirrorPadReflectFixtureSimpleMirrorPadReflectFixture116     SimpleMirrorPadReflectFixture() : MirrorPadFixture("[ 3, 3 ]", "[ 7, 7 ]", "[ 2, 2 ]",
117                                                         "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 2,0,0,0 ]",
118                                                         "REFLECT", "FLOAT32") {}
119 };
120 
121 TEST_CASE_FIXTURE(SimpleMirrorPadReflectFixture, "ParseMirrorPadRelfect")
122 {
123     RunTest<2, armnn::DataType::Float32>
124         (0,
125          {{ "inputTensor",  { 1.0f, 2.0f, 3.0f,
126                               4.0f, 5.0f, 6.0f,
127                               7.0f, 8.0f, 9.0f }}},
128 
129          {{ "outputTensor", { 9.0f, 8.0f, 7.0f, 8.0f, 9.0f, 8.0f, 7.0f,
130                               6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f,
131                               3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f,
132                               6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f,
133                               9.0f, 8.0f, 7.0f, 8.0f, 9.0f, 8.0f, 7.0f,
134                               6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f,
135                               3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f }}});
136 }
137 
138 }
139