xref: /aosp_15_r20/external/armnn/delegate/test/StridedSliceTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "StridedSliceTestHelper.hpp"
7 
8 #include <armnn_delegate.hpp>
9 
10 #include <flatbuffers/flatbuffers.h>
11 
12 #include <doctest/doctest.h>
13 
14 namespace armnnDelegate
15 {
16 
StridedSlice4DTest(std::vector<armnn::BackendId> & backends)17 void StridedSlice4DTest(std::vector<armnn::BackendId>& backends)
18 {
19     std::vector<int32_t> inputShape  { 3, 2, 3, 1 };
20     std::vector<int32_t> outputShape { 1, 2, 3, 1 };
21     std::vector<int32_t> beginShape  { 4 };
22     std::vector<int32_t> endShape    { 4 };
23     std::vector<int32_t> strideShape { 4 };
24 
25     std::vector<int32_t> beginData  { 1, 0, 0, 0 };
26     std::vector<int32_t> endData    { 2, 2, 3, 1 };
27     std::vector<int32_t> strideData { 1, 1, 1, 1 };
28     std::vector<float> inputData  { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
29                                     3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
30                                     5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
31     std::vector<float> outputData { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f };
32 
33     StridedSliceTestImpl<float>(
34             backends,
35             inputData,
36             outputData,
37             beginData,
38             endData,
39             strideData,
40             inputShape,
41             beginShape,
42             endShape,
43             strideShape,
44             outputShape
45             );
46 }
47 
StridedSlice4DReverseTest(std::vector<armnn::BackendId> & backends)48 void StridedSlice4DReverseTest(std::vector<armnn::BackendId>& backends)
49 {
50     std::vector<int32_t> inputShape  { 3, 2, 3, 1 };
51     std::vector<int32_t> outputShape { 1, 2, 3, 1 };
52     std::vector<int32_t> beginShape  { 4 };
53     std::vector<int32_t> endShape    { 4 };
54     std::vector<int32_t> strideShape { 4 };
55 
56     std::vector<int32_t> beginData  { 1, -1, 0, 0 };
57     std::vector<int32_t> endData    { 2, -3, 3, 1 };
58     std::vector<int32_t> strideData { 1, -1, 1, 1 };
59     std::vector<float>   inputData  { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
60                                       3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
61                                       5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
62     std::vector<float>   outputData { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f };
63 
64     StridedSliceTestImpl<float>(
65             backends,
66             inputData,
67             outputData,
68             beginData,
69             endData,
70             strideData,
71             inputShape,
72             beginShape,
73             endShape,
74             strideShape,
75             outputShape
76     );
77 }
78 
StridedSliceSimpleStrideTest(std::vector<armnn::BackendId> & backends)79 void StridedSliceSimpleStrideTest(std::vector<armnn::BackendId>& backends)
80 {
81     std::vector<int32_t> inputShape  { 3, 2, 3, 1 };
82     std::vector<int32_t> outputShape { 2, 1, 2, 1 };
83     std::vector<int32_t> beginShape  { 4 };
84     std::vector<int32_t> endShape    { 4 };
85     std::vector<int32_t> strideShape { 4 };
86 
87     std::vector<int32_t> beginData  { 0, 0, 0, 0 };
88     std::vector<int32_t> endData    { 3, 2, 3, 1 };
89     std::vector<int32_t> strideData { 2, 2, 2, 1 };
90     std::vector<float>   inputData  { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
91                                       3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
92                                       5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
93     std::vector<float>   outputData { 1.0f, 1.0f,
94                                       5.0f, 5.0f };
95 
96     StridedSliceTestImpl<float>(
97             backends,
98             inputData,
99             outputData,
100             beginData,
101             endData,
102             strideData,
103             inputShape,
104             beginShape,
105             endShape,
106             strideShape,
107             outputShape
108     );
109 }
110 
StridedSliceSimpleRangeMaskTest(std::vector<armnn::BackendId> & backends)111 void StridedSliceSimpleRangeMaskTest(std::vector<armnn::BackendId>& backends)
112 {
113     std::vector<int32_t> inputShape  { 3, 2, 3, 1 };
114     std::vector<int32_t> outputShape { 3, 2, 3, 1 };
115     std::vector<int32_t> beginShape  { 4 };
116     std::vector<int32_t> endShape    { 4 };
117     std::vector<int32_t> strideShape { 4 };
118 
119     std::vector<int32_t> beginData  { 1, 1, 1, 1 };
120     std::vector<int32_t> endData    { 1, 1, 1, 1 };
121     std::vector<int32_t> strideData { 1, 1, 1, 1 };
122 
123     int beginMask = -1;
124     int endMask   = -1;
125 
126     std::vector<float>   inputData  { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
127                                       3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
128                                       5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
129     std::vector<float>   outputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
130                                       3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
131                                       5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
132 
133     StridedSliceTestImpl<float>(
134             backends,
135             inputData,
136             outputData,
137             beginData,
138             endData,
139             strideData,
140             inputShape,
141             beginShape,
142             endShape,
143             strideShape,
144             outputShape,
145             beginMask,
146             endMask
147     );
148 }
149 
150 TEST_SUITE("StridedSlice_CpuRefTests")
151 {
152 
153 TEST_CASE ("StridedSlice_4D_CpuRef_Test")
154 {
155     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
156     StridedSlice4DTest(backends);
157 }
158 
159 TEST_CASE ("StridedSlice_4D_Reverse_CpuRef_Test")
160 {
161     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
162     StridedSlice4DReverseTest(backends);
163 }
164 
165 TEST_CASE ("StridedSlice_SimpleStride_CpuRef_Test")
166 {
167     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
168     StridedSliceSimpleStrideTest(backends);
169 }
170 
171 TEST_CASE ("StridedSlice_SimpleRange_CpuRef_Test")
172 {
173     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
174     StridedSliceSimpleRangeMaskTest(backends);
175 }
176 
177 } // StridedSlice_CpuRefTests TestSuite
178 
179 
180 
181 TEST_SUITE("StridedSlice_CpuAccTests")
182 {
183 
184 TEST_CASE ("StridedSlice_4D_CpuAcc_Test")
185 {
186     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
187     StridedSlice4DTest(backends);
188 }
189 
190 TEST_CASE ("StridedSlice_4D_Reverse_CpuAcc_Test")
191 {
192     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
193     StridedSlice4DReverseTest(backends);
194 }
195 
196 TEST_CASE ("StridedSlice_SimpleStride_CpuAcc_Test")
197 {
198     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
199     StridedSliceSimpleStrideTest(backends);
200 }
201 
202 TEST_CASE ("StridedSlice_SimpleRange_CpuAcc_Test")
203 {
204     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
205     StridedSliceSimpleRangeMaskTest(backends);
206 }
207 
208 } // StridedSlice_CpuAccTests TestSuite
209 
210 
211 
212 TEST_SUITE("StridedSlice_GpuAccTests")
213 {
214 
215 TEST_CASE ("StridedSlice_4D_GpuAcc_Test")
216 {
217     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
218     StridedSlice4DTest(backends);
219 }
220 
221 TEST_CASE ("StridedSlice_4D_Reverse_GpuAcc_Test")
222 {
223     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
224     StridedSlice4DReverseTest(backends);
225 }
226 
227 TEST_CASE ("StridedSlice_SimpleStride_GpuAcc_Test")
228 {
229     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
230     StridedSliceSimpleStrideTest(backends);
231 }
232 
233 TEST_CASE ("StridedSlice_SimpleRange_GpuAcc_Test")
234 {
235     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
236     StridedSliceSimpleRangeMaskTest(backends);
237 }
238 
239 } // StridedSlice_GpuAccTests TestSuite
240 
241 } // namespace armnnDelegate