xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/StridedSlice.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_StridedSlice")
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker struct StridedSliceFixture : public ParserFlatbuffersFixture
12*89c4ff92SAndroid Build Coastguard Worker {
StridedSliceFixtureStridedSliceFixture13*89c4ff92SAndroid Build Coastguard Worker     explicit StridedSliceFixture(const std::string & inputShape,
14*89c4ff92SAndroid Build Coastguard Worker                                  const std::string & outputShape,
15*89c4ff92SAndroid Build Coastguard Worker                                  const std::string & beginData,
16*89c4ff92SAndroid Build Coastguard Worker                                  const std::string & endData,
17*89c4ff92SAndroid Build Coastguard Worker                                  const std::string & stridesData,
18*89c4ff92SAndroid Build Coastguard Worker                                  int beginMask = 0,
19*89c4ff92SAndroid Build Coastguard Worker                                  int endMask = 0)
20*89c4ff92SAndroid Build Coastguard Worker     {
21*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
22*89c4ff92SAndroid Build Coastguard Worker             {
23*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
24*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [ { "builtin_code": "STRIDED_SLICE" } ],
25*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [ {
26*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [
27*89c4ff92SAndroid Build Coastguard Worker                         {
28*89c4ff92SAndroid Build Coastguard Worker                             "shape": )" + inputShape + R"(,
29*89c4ff92SAndroid Build Coastguard Worker                             "type": "FLOAT32",
30*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 0,
31*89c4ff92SAndroid Build Coastguard Worker                             "name": "inputTensor",
32*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
33*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
34*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
35*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
36*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
37*89c4ff92SAndroid Build Coastguard Worker                             }
38*89c4ff92SAndroid Build Coastguard Worker                         },
39*89c4ff92SAndroid Build Coastguard Worker                         {
40*89c4ff92SAndroid Build Coastguard Worker                             "shape": [ 4 ],
41*89c4ff92SAndroid Build Coastguard Worker                             "type": "INT32",
42*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 1,
43*89c4ff92SAndroid Build Coastguard Worker                             "name": "beginTensor",
44*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
45*89c4ff92SAndroid Build Coastguard Worker                             }
46*89c4ff92SAndroid Build Coastguard Worker                         },
47*89c4ff92SAndroid Build Coastguard Worker                         {
48*89c4ff92SAndroid Build Coastguard Worker                            "shape": [ 4 ],
49*89c4ff92SAndroid Build Coastguard Worker                             "type": "INT32",
50*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 2,
51*89c4ff92SAndroid Build Coastguard Worker                             "name": "endTensor",
52*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
53*89c4ff92SAndroid Build Coastguard Worker                             }
54*89c4ff92SAndroid Build Coastguard Worker                         },
55*89c4ff92SAndroid Build Coastguard Worker                         {
56*89c4ff92SAndroid Build Coastguard Worker                            "shape": [ 4 ],
57*89c4ff92SAndroid Build Coastguard Worker                             "type": "INT32",
58*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 3,
59*89c4ff92SAndroid Build Coastguard Worker                             "name": "stridesTensor",
60*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
61*89c4ff92SAndroid Build Coastguard Worker                             }
62*89c4ff92SAndroid Build Coastguard Worker                         },
63*89c4ff92SAndroid Build Coastguard Worker                         {
64*89c4ff92SAndroid Build Coastguard Worker                             "shape": )" + outputShape + R"( ,
65*89c4ff92SAndroid Build Coastguard Worker                             "type": "FLOAT32",
66*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 4,
67*89c4ff92SAndroid Build Coastguard Worker                             "name": "outputTensor",
68*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
69*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
70*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
71*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
72*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
73*89c4ff92SAndroid Build Coastguard Worker                             }
74*89c4ff92SAndroid Build Coastguard Worker                         }
75*89c4ff92SAndroid Build Coastguard Worker                     ],
76*89c4ff92SAndroid Build Coastguard Worker                     "inputs": [ 0, 1, 2, 3 ],
77*89c4ff92SAndroid Build Coastguard Worker                     "outputs": [ 4 ],
78*89c4ff92SAndroid Build Coastguard Worker                     "operators": [
79*89c4ff92SAndroid Build Coastguard Worker                         {
80*89c4ff92SAndroid Build Coastguard Worker                             "opcode_index": 0,
81*89c4ff92SAndroid Build Coastguard Worker                             "inputs": [ 0, 1, 2, 3 ],
82*89c4ff92SAndroid Build Coastguard Worker                             "outputs": [ 4 ],
83*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options_type": "StridedSliceOptions",
84*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options": {
85*89c4ff92SAndroid Build Coastguard Worker                                "begin_mask": )"       + std::to_string(beginMask)      + R"(,
86*89c4ff92SAndroid Build Coastguard Worker                                "end_mask": )"         + std::to_string(endMask)        + R"(
87*89c4ff92SAndroid Build Coastguard Worker                             },
88*89c4ff92SAndroid Build Coastguard Worker                             "custom_options_format": "FLEXBUFFERS"
89*89c4ff92SAndroid Build Coastguard Worker                         }
90*89c4ff92SAndroid Build Coastguard Worker                     ],
91*89c4ff92SAndroid Build Coastguard Worker                 } ],
92*89c4ff92SAndroid Build Coastguard Worker                 "buffers" : [
93*89c4ff92SAndroid Build Coastguard Worker                     { },
94*89c4ff92SAndroid Build Coastguard Worker                     { "data": )" + beginData + R"(, },
95*89c4ff92SAndroid Build Coastguard Worker                     { "data": )" + endData + R"(, },
96*89c4ff92SAndroid Build Coastguard Worker                     { "data": )" + stridesData + R"(, },
97*89c4ff92SAndroid Build Coastguard Worker                     { }
98*89c4ff92SAndroid Build Coastguard Worker                 ]
99*89c4ff92SAndroid Build Coastguard Worker             }
100*89c4ff92SAndroid Build Coastguard Worker         )";
101*89c4ff92SAndroid Build Coastguard Worker         Setup();
102*89c4ff92SAndroid Build Coastguard Worker     }
103*89c4ff92SAndroid Build Coastguard Worker };
104*89c4ff92SAndroid Build Coastguard Worker 
105*89c4ff92SAndroid Build Coastguard Worker struct StridedSlice4DFixture : StridedSliceFixture
106*89c4ff92SAndroid Build Coastguard Worker {
StridedSlice4DFixtureStridedSlice4DFixture107*89c4ff92SAndroid Build Coastguard Worker     StridedSlice4DFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",  // inputShape
108*89c4ff92SAndroid Build Coastguard Worker                                                   "[ 1, 2, 3, 1 ]",  // outputShape
109*89c4ff92SAndroid Build Coastguard Worker                                                   "[ 1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]",  // beginData
110*89c4ff92SAndroid Build Coastguard Worker                                                   "[ 2,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]",  // endData
111*89c4ff92SAndroid Build Coastguard Worker                                                   "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]"   // stridesData
112*89c4ff92SAndroid Build Coastguard Worker                                                  ) {}
113*89c4ff92SAndroid Build Coastguard Worker };
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(StridedSlice4DFixture, "StridedSlice4D")
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker   RunTest<4, armnn::DataType::Float32>(
118*89c4ff92SAndroid Build Coastguard Worker       0,
119*89c4ff92SAndroid Build Coastguard Worker       {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker                          3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker                          5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker       {{"outputTensor", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}});
126*89c4ff92SAndroid Build Coastguard Worker }
127*89c4ff92SAndroid Build Coastguard Worker 
128*89c4ff92SAndroid Build Coastguard Worker struct StridedSlice4DReverseFixture : StridedSliceFixture
129*89c4ff92SAndroid Build Coastguard Worker {
StridedSlice4DReverseFixtureStridedSlice4DReverseFixture130*89c4ff92SAndroid Build Coastguard Worker     StridedSlice4DReverseFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",    // inputShape
131*89c4ff92SAndroid Build Coastguard Worker                                                          "[ 1, 2, 3, 1 ]",    // outputShape
132*89c4ff92SAndroid Build Coastguard Worker                                                          "[ 1,0,0,0, "
133*89c4ff92SAndroid Build Coastguard Worker                                                          "255,255,255,255, "
134*89c4ff92SAndroid Build Coastguard Worker                                                          "0,0,0,0, "
135*89c4ff92SAndroid Build Coastguard Worker                                                          "0,0,0,0 ]",  // beginData    [ 1 -1 0 0 ]
136*89c4ff92SAndroid Build Coastguard Worker                                                          "[ 2,0,0,0, "
137*89c4ff92SAndroid Build Coastguard Worker                                                          "253,255,255,255, "
138*89c4ff92SAndroid Build Coastguard Worker                                                          "3,0,0,0, "
139*89c4ff92SAndroid Build Coastguard Worker                                                          "1,0,0,0 ]",  // endData      [ 2 -3 3 1 ]
140*89c4ff92SAndroid Build Coastguard Worker                                                          "[ 1,0,0,0, "
141*89c4ff92SAndroid Build Coastguard Worker                                                          "255,255,255,255, "
142*89c4ff92SAndroid Build Coastguard Worker                                                          "1,0,0,0, "
143*89c4ff92SAndroid Build Coastguard Worker                                                          "1,0,0,0 ]"   // stridesData  [ 1 -1 1 1 ]
144*89c4ff92SAndroid Build Coastguard Worker                                                         ) {}
145*89c4ff92SAndroid Build Coastguard Worker };
146*89c4ff92SAndroid Build Coastguard Worker 
147*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(StridedSlice4DReverseFixture, "StridedSlice4DReverse")
148*89c4ff92SAndroid Build Coastguard Worker {
149*89c4ff92SAndroid Build Coastguard Worker   RunTest<4, armnn::DataType::Float32>(
150*89c4ff92SAndroid Build Coastguard Worker       0,
151*89c4ff92SAndroid Build Coastguard Worker       {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker                          3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker                          5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
156*89c4ff92SAndroid Build Coastguard Worker 
157*89c4ff92SAndroid Build Coastguard Worker       {{"outputTensor", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}});
158*89c4ff92SAndroid Build Coastguard Worker }
159*89c4ff92SAndroid Build Coastguard Worker 
160*89c4ff92SAndroid Build Coastguard Worker struct StridedSliceSimpleStrideFixture : StridedSliceFixture
161*89c4ff92SAndroid Build Coastguard Worker {
StridedSliceSimpleStrideFixtureStridedSliceSimpleStrideFixture162*89c4ff92SAndroid Build Coastguard Worker     StridedSliceSimpleStrideFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",  // inputShape
163*89c4ff92SAndroid Build Coastguard Worker                                                             "[ 2, 1, 2, 1 ]",  // outputShape
164*89c4ff92SAndroid Build Coastguard Worker                                                             "[ 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]",  // beginData
165*89c4ff92SAndroid Build Coastguard Worker                                                             "[ 3,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]",  // endData
166*89c4ff92SAndroid Build Coastguard Worker                                                             "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 1,0,0,0 ]"   // stridesData
167*89c4ff92SAndroid Build Coastguard Worker                                                  ) {}
168*89c4ff92SAndroid Build Coastguard Worker };
169*89c4ff92SAndroid Build Coastguard Worker 
170*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(StridedSliceSimpleStrideFixture, "StridedSliceSimpleStride")
171*89c4ff92SAndroid Build Coastguard Worker {
172*89c4ff92SAndroid Build Coastguard Worker   RunTest<4, armnn::DataType::Float32>(
173*89c4ff92SAndroid Build Coastguard Worker       0,
174*89c4ff92SAndroid Build Coastguard Worker       {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
175*89c4ff92SAndroid Build Coastguard Worker 
176*89c4ff92SAndroid Build Coastguard Worker                          3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker                          5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
179*89c4ff92SAndroid Build Coastguard Worker 
180*89c4ff92SAndroid Build Coastguard Worker       {{"outputTensor", { 1.0f, 1.0f,
181*89c4ff92SAndroid Build Coastguard Worker 
182*89c4ff92SAndroid Build Coastguard Worker                           5.0f, 5.0f }}});
183*89c4ff92SAndroid Build Coastguard Worker }
184*89c4ff92SAndroid Build Coastguard Worker 
185*89c4ff92SAndroid Build Coastguard Worker struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
186*89c4ff92SAndroid Build Coastguard Worker {
StridedSliceSimpleRangeMaskFixtureStridedSliceSimpleRangeMaskFixture187*89c4ff92SAndroid Build Coastguard Worker     StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",  // inputShape
188*89c4ff92SAndroid Build Coastguard Worker                                                                "[ 3, 2, 3, 1 ]",  // outputShape
189*89c4ff92SAndroid Build Coastguard Worker                                                                "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]",  // beginData
190*89c4ff92SAndroid Build Coastguard Worker                                                                "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]",  // endData
191*89c4ff92SAndroid Build Coastguard Worker                                                                "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]",  // stridesData
192*89c4ff92SAndroid Build Coastguard Worker                                                                (1 << 4) - 1,  // beginMask
193*89c4ff92SAndroid Build Coastguard Worker                                                                (1 << 4) - 1   // endMask
194*89c4ff92SAndroid Build Coastguard Worker                                                  ) {}
195*89c4ff92SAndroid Build Coastguard Worker };
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(StridedSliceSimpleRangeMaskFixture, "StridedSliceSimpleRangeMask")
198*89c4ff92SAndroid Build Coastguard Worker {
199*89c4ff92SAndroid Build Coastguard Worker   RunTest<4, armnn::DataType::Float32>(
200*89c4ff92SAndroid Build Coastguard Worker       0,
201*89c4ff92SAndroid Build Coastguard Worker       {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
202*89c4ff92SAndroid Build Coastguard Worker 
203*89c4ff92SAndroid Build Coastguard Worker                          3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
204*89c4ff92SAndroid Build Coastguard Worker 
205*89c4ff92SAndroid Build Coastguard Worker                          5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
206*89c4ff92SAndroid Build Coastguard Worker 
207*89c4ff92SAndroid Build Coastguard Worker       {{"outputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
208*89c4ff92SAndroid Build Coastguard Worker 
209*89c4ff92SAndroid Build Coastguard Worker                           3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
210*89c4ff92SAndroid Build Coastguard Worker 
211*89c4ff92SAndroid Build Coastguard Worker                           5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}});
212*89c4ff92SAndroid Build Coastguard Worker }
213*89c4ff92SAndroid Build Coastguard Worker 
214*89c4ff92SAndroid Build Coastguard Worker }
215