xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Slice.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_Slice")
9*89c4ff92SAndroid Build Coastguard Worker {
10*89c4ff92SAndroid Build Coastguard Worker struct SliceFixture : public ParserFlatbuffersFixture
11*89c4ff92SAndroid Build Coastguard Worker {
SliceFixtureSliceFixture12*89c4ff92SAndroid Build Coastguard Worker     explicit SliceFixture(const std::string & inputShape,
13*89c4ff92SAndroid Build Coastguard Worker                           const std::string & outputShape,
14*89c4ff92SAndroid Build Coastguard Worker                           const std::string & beginData,
15*89c4ff92SAndroid Build Coastguard Worker                           const std::string & sizeData)
16*89c4ff92SAndroid Build Coastguard Worker     {
17*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
18*89c4ff92SAndroid Build Coastguard Worker             {
19*89c4ff92SAndroid Build Coastguard Worker                   "version": 3,
20*89c4ff92SAndroid Build Coastguard Worker                   "operator_codes": [
21*89c4ff92SAndroid Build Coastguard Worker                     {
22*89c4ff92SAndroid Build Coastguard Worker                       "builtin_code": "SLICE",
23*89c4ff92SAndroid Build Coastguard Worker                       "version": 1
24*89c4ff92SAndroid Build Coastguard Worker                     }
25*89c4ff92SAndroid Build Coastguard Worker                   ],
26*89c4ff92SAndroid Build Coastguard Worker                   "subgraphs": [
27*89c4ff92SAndroid Build Coastguard Worker                     {
28*89c4ff92SAndroid Build Coastguard Worker                       "tensors": [
29*89c4ff92SAndroid Build Coastguard Worker                         {
30*89c4ff92SAndroid Build Coastguard Worker                           "shape": )" + inputShape + R"(,
31*89c4ff92SAndroid Build Coastguard Worker                           "type": "FLOAT32",
32*89c4ff92SAndroid Build Coastguard Worker                           "buffer": 0,
33*89c4ff92SAndroid Build Coastguard Worker                           "name": "inputTensor",
34*89c4ff92SAndroid Build Coastguard Worker                           "quantization": {
35*89c4ff92SAndroid Build Coastguard Worker                             "min": [
36*89c4ff92SAndroid Build Coastguard Worker                               0.0
37*89c4ff92SAndroid Build Coastguard Worker                             ],
38*89c4ff92SAndroid Build Coastguard Worker                             "max": [
39*89c4ff92SAndroid Build Coastguard Worker                               255.0
40*89c4ff92SAndroid Build Coastguard Worker                             ],
41*89c4ff92SAndroid Build Coastguard Worker                             "details_type": 0,
42*89c4ff92SAndroid Build Coastguard Worker                             "quantized_dimension": 0
43*89c4ff92SAndroid Build Coastguard Worker                           },
44*89c4ff92SAndroid Build Coastguard Worker                           "is_variable": false
45*89c4ff92SAndroid Build Coastguard Worker                         },
46*89c4ff92SAndroid Build Coastguard Worker                         {
47*89c4ff92SAndroid Build Coastguard Worker                           "shape": )" + outputShape + R"(,
48*89c4ff92SAndroid Build Coastguard Worker                           "type": "FLOAT32",
49*89c4ff92SAndroid Build Coastguard Worker                           "buffer": 1,
50*89c4ff92SAndroid Build Coastguard Worker                           "name": "outputTensor",
51*89c4ff92SAndroid Build Coastguard Worker                           "quantization": {
52*89c4ff92SAndroid Build Coastguard Worker                             "details_type": 0,
53*89c4ff92SAndroid Build Coastguard Worker                             "quantized_dimension": 0
54*89c4ff92SAndroid Build Coastguard Worker                           },
55*89c4ff92SAndroid Build Coastguard Worker                           "is_variable": false
56*89c4ff92SAndroid Build Coastguard Worker                         })";
57*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(,
58*89c4ff92SAndroid Build Coastguard Worker                             {
59*89c4ff92SAndroid Build Coastguard Worker                             "shape": [
60*89c4ff92SAndroid Build Coastguard Worker                                 3
61*89c4ff92SAndroid Build Coastguard Worker                             ],
62*89c4ff92SAndroid Build Coastguard Worker                             "type": "INT32",
63*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 2,
64*89c4ff92SAndroid Build Coastguard Worker                             "name": "beginTensor",
65*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
66*89c4ff92SAndroid Build Coastguard Worker                             }
67*89c4ff92SAndroid Build Coastguard Worker                             })";
68*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(,
69*89c4ff92SAndroid Build Coastguard Worker                             {
70*89c4ff92SAndroid Build Coastguard Worker                             "shape": [
71*89c4ff92SAndroid Build Coastguard Worker                                 3
72*89c4ff92SAndroid Build Coastguard Worker                             ],
73*89c4ff92SAndroid Build Coastguard Worker                             "type": "INT32",
74*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 3,
75*89c4ff92SAndroid Build Coastguard Worker                             "name": "sizeTensor",
76*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
77*89c4ff92SAndroid Build Coastguard Worker                             }
78*89c4ff92SAndroid Build Coastguard Worker                             })";
79*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(],
80*89c4ff92SAndroid Build Coastguard Worker                       "inputs": [
81*89c4ff92SAndroid Build Coastguard Worker                         0
82*89c4ff92SAndroid Build Coastguard Worker                       ],
83*89c4ff92SAndroid Build Coastguard Worker                       "outputs": [
84*89c4ff92SAndroid Build Coastguard Worker                         1
85*89c4ff92SAndroid Build Coastguard Worker                       ],
86*89c4ff92SAndroid Build Coastguard Worker                       "operators": [
87*89c4ff92SAndroid Build Coastguard Worker                         {
88*89c4ff92SAndroid Build Coastguard Worker                           "opcode_index": 0,
89*89c4ff92SAndroid Build Coastguard Worker                           "inputs": [
90*89c4ff92SAndroid Build Coastguard Worker                             0,
91*89c4ff92SAndroid Build Coastguard Worker                             2,
92*89c4ff92SAndroid Build Coastguard Worker                             3)";
93*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(],
94*89c4ff92SAndroid Build Coastguard Worker                           "outputs": [
95*89c4ff92SAndroid Build Coastguard Worker                             1
96*89c4ff92SAndroid Build Coastguard Worker                           ],
97*89c4ff92SAndroid Build Coastguard Worker                           mutating_variable_inputs: [
98*89c4ff92SAndroid Build Coastguard Worker                           ]
99*89c4ff92SAndroid Build Coastguard Worker                         }
100*89c4ff92SAndroid Build Coastguard Worker                       ]
101*89c4ff92SAndroid Build Coastguard Worker                     }
102*89c4ff92SAndroid Build Coastguard Worker                   ],
103*89c4ff92SAndroid Build Coastguard Worker                   "description": "TOCO Converted.",
104*89c4ff92SAndroid Build Coastguard Worker                   "buffers": [
105*89c4ff92SAndroid Build Coastguard Worker                     { },
106*89c4ff92SAndroid Build Coastguard Worker                     { })";
107*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(,{"data": )" + beginData + R"( })";
108*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(,{"data": )" + sizeData + R"( })";
109*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(
110*89c4ff92SAndroid Build Coastguard Worker                   ]
111*89c4ff92SAndroid Build Coastguard Worker                 }
112*89c4ff92SAndroid Build Coastguard Worker         )";
113*89c4ff92SAndroid Build Coastguard Worker         SetupSingleInputSingleOutput("inputTensor", "outputTensor");
114*89c4ff92SAndroid Build Coastguard Worker     }
115*89c4ff92SAndroid Build Coastguard Worker };
116*89c4ff92SAndroid Build Coastguard Worker 
117*89c4ff92SAndroid Build Coastguard Worker struct SliceFixtureSingleDim : SliceFixture
118*89c4ff92SAndroid Build Coastguard Worker {
SliceFixtureSingleDimSliceFixtureSingleDim119*89c4ff92SAndroid Build Coastguard Worker     SliceFixtureSingleDim() : SliceFixture("[ 3, 2, 3 ]",
120*89c4ff92SAndroid Build Coastguard Worker                                            "[ 1, 1, 3 ]",
121*89c4ff92SAndroid Build Coastguard Worker                                            "[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]",
122*89c4ff92SAndroid Build Coastguard Worker                                            "[ 1, 0, 0, 0, 1, 0, 0, 0, 3, 0, 0, 0 ]") {}
123*89c4ff92SAndroid Build Coastguard Worker };
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SliceFixtureSingleDim, "SliceSingleDim")
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::Float32>(
128*89c4ff92SAndroid Build Coastguard Worker       0,
129*89c4ff92SAndroid Build Coastguard Worker       {{"inputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }}},
130*89c4ff92SAndroid Build Coastguard Worker       {{"outputTensor", { 3, 3, 3 }}});
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
133*89c4ff92SAndroid Build Coastguard Worker                 == armnn::TensorShape({1,1,3})));
134*89c4ff92SAndroid Build Coastguard Worker }
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker struct SliceFixtureD123 : SliceFixture
137*89c4ff92SAndroid Build Coastguard Worker {
SliceFixtureD123SliceFixtureD123138*89c4ff92SAndroid Build Coastguard Worker     SliceFixtureD123() : SliceFixture("[ 3, 2, 3 ]",
139*89c4ff92SAndroid Build Coastguard Worker                                       "[ 1, 2, 3 ]",
140*89c4ff92SAndroid Build Coastguard Worker                                       "[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]",
141*89c4ff92SAndroid Build Coastguard Worker                                       "[ 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0 ]") {}
142*89c4ff92SAndroid Build Coastguard Worker };
143*89c4ff92SAndroid Build Coastguard Worker 
144*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SliceFixtureD123, "SliceD123")
145*89c4ff92SAndroid Build Coastguard Worker {
146*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::Float32>(
147*89c4ff92SAndroid Build Coastguard Worker         0,
148*89c4ff92SAndroid Build Coastguard Worker         {{"inputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }}},
149*89c4ff92SAndroid Build Coastguard Worker         {{"outputTensor", { 3, 3, 3, 4, 4, 4 }}});
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
152*89c4ff92SAndroid Build Coastguard Worker                 == armnn::TensorShape({1,2,3})));
153*89c4ff92SAndroid Build Coastguard Worker }
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker struct SliceFixtureD213 : SliceFixture
156*89c4ff92SAndroid Build Coastguard Worker {
SliceFixtureD213SliceFixtureD213157*89c4ff92SAndroid Build Coastguard Worker     SliceFixtureD213() : SliceFixture("[ 3, 2, 3 ]",
158*89c4ff92SAndroid Build Coastguard Worker                                       "[ 2, 1, 3 ]",
159*89c4ff92SAndroid Build Coastguard Worker                                       "[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]",
160*89c4ff92SAndroid Build Coastguard Worker                                       "[ 2, 0, 0, 0, 1, 0, 0, 0, 3, 0, 0, 0 ]") {}
161*89c4ff92SAndroid Build Coastguard Worker };
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SliceFixtureD213, "SliceD213")
164*89c4ff92SAndroid Build Coastguard Worker {
165*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::Float32>(
166*89c4ff92SAndroid Build Coastguard Worker         0,
167*89c4ff92SAndroid Build Coastguard Worker         {{"inputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }}},
168*89c4ff92SAndroid Build Coastguard Worker         {{"outputTensor", { 3, 3, 3, 5, 5, 5 }}});
169*89c4ff92SAndroid Build Coastguard Worker 
170*89c4ff92SAndroid Build Coastguard Worker     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
171*89c4ff92SAndroid Build Coastguard Worker                 == armnn::TensorShape({2,1,3})));
172*89c4ff92SAndroid Build Coastguard Worker }
173*89c4ff92SAndroid Build Coastguard Worker 
174*89c4ff92SAndroid Build Coastguard Worker struct DynamicSliceFixtureD213 : SliceFixture
175*89c4ff92SAndroid Build Coastguard Worker {
DynamicSliceFixtureD213DynamicSliceFixtureD213176*89c4ff92SAndroid Build Coastguard Worker     DynamicSliceFixtureD213() : SliceFixture("[ 3, 2, 3 ]",
177*89c4ff92SAndroid Build Coastguard Worker                                              "[ ]",
178*89c4ff92SAndroid Build Coastguard Worker                                              "[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]",
179*89c4ff92SAndroid Build Coastguard Worker                                              "[ 255, 255, 255, 255, 1, 0, 0, 0, 255, 255, 255, 255 ]") {}
180*89c4ff92SAndroid Build Coastguard Worker };
181*89c4ff92SAndroid Build Coastguard Worker 
182*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DynamicSliceFixtureD213, "DynamicSliceD213")
183*89c4ff92SAndroid Build Coastguard Worker {
184*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::Float32, armnn::DataType::Float32>(
185*89c4ff92SAndroid Build Coastguard Worker         0,
186*89c4ff92SAndroid Build Coastguard Worker         {{"inputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }}},
187*89c4ff92SAndroid Build Coastguard Worker         {{"outputTensor", { 3, 3, 3, 5, 5, 5 }}},
188*89c4ff92SAndroid Build Coastguard Worker         true);
189*89c4ff92SAndroid Build Coastguard Worker }
190*89c4ff92SAndroid Build Coastguard Worker }