xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/StridedSlice.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 
9 TEST_SUITE("TensorflowLiteParser_StridedSlice")
10 {
11 struct StridedSliceFixture : public ParserFlatbuffersFixture
12 {
StridedSliceFixtureStridedSliceFixture13     explicit StridedSliceFixture(const std::string & inputShape,
14                                  const std::string & outputShape,
15                                  const std::string & beginData,
16                                  const std::string & endData,
17                                  const std::string & stridesData,
18                                  int beginMask = 0,
19                                  int endMask = 0)
20     {
21         m_JsonString = R"(
22             {
23                 "version": 3,
24                 "operator_codes": [ { "builtin_code": "STRIDED_SLICE" } ],
25                 "subgraphs": [ {
26                     "tensors": [
27                         {
28                             "shape": )" + inputShape + R"(,
29                             "type": "FLOAT32",
30                             "buffer": 0,
31                             "name": "inputTensor",
32                             "quantization": {
33                                 "min": [ 0.0 ],
34                                 "max": [ 255.0 ],
35                                 "scale": [ 1.0 ],
36                                 "zero_point": [ 0 ],
37                             }
38                         },
39                         {
40                             "shape": [ 4 ],
41                             "type": "INT32",
42                             "buffer": 1,
43                             "name": "beginTensor",
44                             "quantization": {
45                             }
46                         },
47                         {
48                            "shape": [ 4 ],
49                             "type": "INT32",
50                             "buffer": 2,
51                             "name": "endTensor",
52                             "quantization": {
53                             }
54                         },
55                         {
56                            "shape": [ 4 ],
57                             "type": "INT32",
58                             "buffer": 3,
59                             "name": "stridesTensor",
60                             "quantization": {
61                             }
62                         },
63                         {
64                             "shape": )" + outputShape + R"( ,
65                             "type": "FLOAT32",
66                             "buffer": 4,
67                             "name": "outputTensor",
68                             "quantization": {
69                                 "min": [ 0.0 ],
70                                 "max": [ 255.0 ],
71                                 "scale": [ 1.0 ],
72                                 "zero_point": [ 0 ],
73                             }
74                         }
75                     ],
76                     "inputs": [ 0, 1, 2, 3 ],
77                     "outputs": [ 4 ],
78                     "operators": [
79                         {
80                             "opcode_index": 0,
81                             "inputs": [ 0, 1, 2, 3 ],
82                             "outputs": [ 4 ],
83                             "builtin_options_type": "StridedSliceOptions",
84                             "builtin_options": {
85                                "begin_mask": )"       + std::to_string(beginMask)      + R"(,
86                                "end_mask": )"         + std::to_string(endMask)        + R"(
87                             },
88                             "custom_options_format": "FLEXBUFFERS"
89                         }
90                     ],
91                 } ],
92                 "buffers" : [
93                     { },
94                     { "data": )" + beginData + R"(, },
95                     { "data": )" + endData + R"(, },
96                     { "data": )" + stridesData + R"(, },
97                     { }
98                 ]
99             }
100         )";
101         Setup();
102     }
103 };
104 
105 struct StridedSlice4DFixture : StridedSliceFixture
106 {
StridedSlice4DFixtureStridedSlice4DFixture107     StridedSlice4DFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",  // inputShape
108                                                   "[ 1, 2, 3, 1 ]",  // outputShape
109                                                   "[ 1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]",  // beginData
110                                                   "[ 2,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]",  // endData
111                                                   "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]"   // stridesData
112                                                  ) {}
113 };
114 
115 TEST_CASE_FIXTURE(StridedSlice4DFixture, "StridedSlice4D")
116 {
117   RunTest<4, armnn::DataType::Float32>(
118       0,
119       {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
120 
121                          3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
122 
123                          5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
124 
125       {{"outputTensor", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}});
126 }
127 
128 struct StridedSlice4DReverseFixture : StridedSliceFixture
129 {
StridedSlice4DReverseFixtureStridedSlice4DReverseFixture130     StridedSlice4DReverseFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",    // inputShape
131                                                          "[ 1, 2, 3, 1 ]",    // outputShape
132                                                          "[ 1,0,0,0, "
133                                                          "255,255,255,255, "
134                                                          "0,0,0,0, "
135                                                          "0,0,0,0 ]",  // beginData    [ 1 -1 0 0 ]
136                                                          "[ 2,0,0,0, "
137                                                          "253,255,255,255, "
138                                                          "3,0,0,0, "
139                                                          "1,0,0,0 ]",  // endData      [ 2 -3 3 1 ]
140                                                          "[ 1,0,0,0, "
141                                                          "255,255,255,255, "
142                                                          "1,0,0,0, "
143                                                          "1,0,0,0 ]"   // stridesData  [ 1 -1 1 1 ]
144                                                         ) {}
145 };
146 
147 TEST_CASE_FIXTURE(StridedSlice4DReverseFixture, "StridedSlice4DReverse")
148 {
149   RunTest<4, armnn::DataType::Float32>(
150       0,
151       {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
152 
153                          3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
154 
155                          5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
156 
157       {{"outputTensor", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}});
158 }
159 
160 struct StridedSliceSimpleStrideFixture : StridedSliceFixture
161 {
StridedSliceSimpleStrideFixtureStridedSliceSimpleStrideFixture162     StridedSliceSimpleStrideFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",  // inputShape
163                                                             "[ 2, 1, 2, 1 ]",  // outputShape
164                                                             "[ 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]",  // beginData
165                                                             "[ 3,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]",  // endData
166                                                             "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 1,0,0,0 ]"   // stridesData
167                                                  ) {}
168 };
169 
170 TEST_CASE_FIXTURE(StridedSliceSimpleStrideFixture, "StridedSliceSimpleStride")
171 {
172   RunTest<4, armnn::DataType::Float32>(
173       0,
174       {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
175 
176                          3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
177 
178                          5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
179 
180       {{"outputTensor", { 1.0f, 1.0f,
181 
182                           5.0f, 5.0f }}});
183 }
184 
185 struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
186 {
StridedSliceSimpleRangeMaskFixtureStridedSliceSimpleRangeMaskFixture187     StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",  // inputShape
188                                                                "[ 3, 2, 3, 1 ]",  // outputShape
189                                                                "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]",  // beginData
190                                                                "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]",  // endData
191                                                                "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]",  // stridesData
192                                                                (1 << 4) - 1,  // beginMask
193                                                                (1 << 4) - 1   // endMask
194                                                  ) {}
195 };
196 
197 TEST_CASE_FIXTURE(StridedSliceSimpleRangeMaskFixture, "StridedSliceSimpleRangeMask")
198 {
199   RunTest<4, armnn::DataType::Float32>(
200       0,
201       {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
202 
203                          3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
204 
205                          5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
206 
207       {{"outputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
208 
209                           3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
210 
211                           5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}});
212 }
213 
214 }
215