xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Slice.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 TEST_SUITE("TensorflowLiteParser_Slice")
9 {
10 struct SliceFixture : public ParserFlatbuffersFixture
11 {
SliceFixtureSliceFixture12     explicit SliceFixture(const std::string & inputShape,
13                           const std::string & outputShape,
14                           const std::string & beginData,
15                           const std::string & sizeData)
16     {
17         m_JsonString = R"(
18             {
19                   "version": 3,
20                   "operator_codes": [
21                     {
22                       "builtin_code": "SLICE",
23                       "version": 1
24                     }
25                   ],
26                   "subgraphs": [
27                     {
28                       "tensors": [
29                         {
30                           "shape": )" + inputShape + R"(,
31                           "type": "FLOAT32",
32                           "buffer": 0,
33                           "name": "inputTensor",
34                           "quantization": {
35                             "min": [
36                               0.0
37                             ],
38                             "max": [
39                               255.0
40                             ],
41                             "details_type": 0,
42                             "quantized_dimension": 0
43                           },
44                           "is_variable": false
45                         },
46                         {
47                           "shape": )" + outputShape + R"(,
48                           "type": "FLOAT32",
49                           "buffer": 1,
50                           "name": "outputTensor",
51                           "quantization": {
52                             "details_type": 0,
53                             "quantized_dimension": 0
54                           },
55                           "is_variable": false
56                         })";
57         m_JsonString += R"(,
58                             {
59                             "shape": [
60                                 3
61                             ],
62                             "type": "INT32",
63                             "buffer": 2,
64                             "name": "beginTensor",
65                             "quantization": {
66                             }
67                             })";
68         m_JsonString += R"(,
69                             {
70                             "shape": [
71                                 3
72                             ],
73                             "type": "INT32",
74                             "buffer": 3,
75                             "name": "sizeTensor",
76                             "quantization": {
77                             }
78                             })";
79         m_JsonString += R"(],
80                       "inputs": [
81                         0
82                       ],
83                       "outputs": [
84                         1
85                       ],
86                       "operators": [
87                         {
88                           "opcode_index": 0,
89                           "inputs": [
90                             0,
91                             2,
92                             3)";
93         m_JsonString += R"(],
94                           "outputs": [
95                             1
96                           ],
97                           mutating_variable_inputs: [
98                           ]
99                         }
100                       ]
101                     }
102                   ],
103                   "description": "TOCO Converted.",
104                   "buffers": [
105                     { },
106                     { })";
107         m_JsonString += R"(,{"data": )" + beginData + R"( })";
108         m_JsonString += R"(,{"data": )" + sizeData + R"( })";
109         m_JsonString += R"(
110                   ]
111                 }
112         )";
113         SetupSingleInputSingleOutput("inputTensor", "outputTensor");
114     }
115 };
116 
117 struct SliceFixtureSingleDim : SliceFixture
118 {
SliceFixtureSingleDimSliceFixtureSingleDim119     SliceFixtureSingleDim() : SliceFixture("[ 3, 2, 3 ]",
120                                            "[ 1, 1, 3 ]",
121                                            "[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]",
122                                            "[ 1, 0, 0, 0, 1, 0, 0, 0, 3, 0, 0, 0 ]") {}
123 };
124 
125 TEST_CASE_FIXTURE(SliceFixtureSingleDim, "SliceSingleDim")
126 {
127     RunTest<3, armnn::DataType::Float32>(
128       0,
129       {{"inputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }}},
130       {{"outputTensor", { 3, 3, 3 }}});
131 
132     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
133                 == armnn::TensorShape({1,1,3})));
134 }
135 
136 struct SliceFixtureD123 : SliceFixture
137 {
SliceFixtureD123SliceFixtureD123138     SliceFixtureD123() : SliceFixture("[ 3, 2, 3 ]",
139                                       "[ 1, 2, 3 ]",
140                                       "[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]",
141                                       "[ 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0 ]") {}
142 };
143 
144 TEST_CASE_FIXTURE(SliceFixtureD123, "SliceD123")
145 {
146     RunTest<3, armnn::DataType::Float32>(
147         0,
148         {{"inputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }}},
149         {{"outputTensor", { 3, 3, 3, 4, 4, 4 }}});
150 
151     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
152                 == armnn::TensorShape({1,2,3})));
153 }
154 
155 struct SliceFixtureD213 : SliceFixture
156 {
SliceFixtureD213SliceFixtureD213157     SliceFixtureD213() : SliceFixture("[ 3, 2, 3 ]",
158                                       "[ 2, 1, 3 ]",
159                                       "[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]",
160                                       "[ 2, 0, 0, 0, 1, 0, 0, 0, 3, 0, 0, 0 ]") {}
161 };
162 
163 TEST_CASE_FIXTURE(SliceFixtureD213, "SliceD213")
164 {
165     RunTest<3, armnn::DataType::Float32>(
166         0,
167         {{"inputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }}},
168         {{"outputTensor", { 3, 3, 3, 5, 5, 5 }}});
169 
170     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
171                 == armnn::TensorShape({2,1,3})));
172 }
173 
174 struct DynamicSliceFixtureD213 : SliceFixture
175 {
DynamicSliceFixtureD213DynamicSliceFixtureD213176     DynamicSliceFixtureD213() : SliceFixture("[ 3, 2, 3 ]",
177                                              "[ ]",
178                                              "[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]",
179                                              "[ 255, 255, 255, 255, 1, 0, 0, 0, 255, 255, 255, 255 ]") {}
180 };
181 
182 TEST_CASE_FIXTURE(DynamicSliceFixtureD213, "DynamicSliceD213")
183 {
184     RunTest<3, armnn::DataType::Float32, armnn::DataType::Float32>(
185         0,
186         {{"inputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }}},
187         {{"outputTensor", { 3, 3, 3, 5, 5, 5 }}},
188         true);
189 }
190 }