xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeStridedSlice.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <string>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_StridedSlice")
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker struct StridedSliceFixture : public ParserFlatbuffersSerializeFixture
14*89c4ff92SAndroid Build Coastguard Worker {
StridedSliceFixtureStridedSliceFixture15*89c4ff92SAndroid Build Coastguard Worker     explicit StridedSliceFixture(const std::string& inputShape,
16*89c4ff92SAndroid Build Coastguard Worker                                  const std::string& begin,
17*89c4ff92SAndroid Build Coastguard Worker                                  const std::string& end,
18*89c4ff92SAndroid Build Coastguard Worker                                  const std::string& stride,
19*89c4ff92SAndroid Build Coastguard Worker                                  const std::string& beginMask,
20*89c4ff92SAndroid Build Coastguard Worker                                  const std::string& endMask,
21*89c4ff92SAndroid Build Coastguard Worker                                  const std::string& shrinkAxisMask,
22*89c4ff92SAndroid Build Coastguard Worker                                  const std::string& ellipsisMask,
23*89c4ff92SAndroid Build Coastguard Worker                                  const std::string& newAxisMask,
24*89c4ff92SAndroid Build Coastguard Worker                                  const std::string& dataLayout,
25*89c4ff92SAndroid Build Coastguard Worker                                  const std::string& outputShape,
26*89c4ff92SAndroid Build Coastguard Worker                                  const std::string& dataType)
27*89c4ff92SAndroid Build Coastguard Worker     {
28*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
29*89c4ff92SAndroid Build Coastguard Worker             {
30*89c4ff92SAndroid Build Coastguard Worker                 inputIds: [0],
31*89c4ff92SAndroid Build Coastguard Worker                 outputIds: [2],
32*89c4ff92SAndroid Build Coastguard Worker                 layers: [
33*89c4ff92SAndroid Build Coastguard Worker                     {
34*89c4ff92SAndroid Build Coastguard Worker                         layer_type: "InputLayer",
35*89c4ff92SAndroid Build Coastguard Worker                         layer: {
36*89c4ff92SAndroid Build Coastguard Worker                             base: {
37*89c4ff92SAndroid Build Coastguard Worker                                 layerBindingId: 0,
38*89c4ff92SAndroid Build Coastguard Worker                                 base: {
39*89c4ff92SAndroid Build Coastguard Worker                                     index: 0,
40*89c4ff92SAndroid Build Coastguard Worker                                     layerName: "InputLayer",
41*89c4ff92SAndroid Build Coastguard Worker                                     layerType: "Input",
42*89c4ff92SAndroid Build Coastguard Worker                                     inputSlots: [{
43*89c4ff92SAndroid Build Coastguard Worker                                         index: 0,
44*89c4ff92SAndroid Build Coastguard Worker                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
45*89c4ff92SAndroid Build Coastguard Worker                                     }],
46*89c4ff92SAndroid Build Coastguard Worker                                     outputSlots: [{
47*89c4ff92SAndroid Build Coastguard Worker                                         index: 0,
48*89c4ff92SAndroid Build Coastguard Worker                                         tensorInfo: {
49*89c4ff92SAndroid Build Coastguard Worker                                             dimensions: )" + inputShape + R"(,
50*89c4ff92SAndroid Build Coastguard Worker                                             dataType: )" + dataType + R"(
51*89c4ff92SAndroid Build Coastguard Worker                                         }
52*89c4ff92SAndroid Build Coastguard Worker                                     }]
53*89c4ff92SAndroid Build Coastguard Worker                                 }
54*89c4ff92SAndroid Build Coastguard Worker                             }
55*89c4ff92SAndroid Build Coastguard Worker                         }
56*89c4ff92SAndroid Build Coastguard Worker                     },
57*89c4ff92SAndroid Build Coastguard Worker                     {
58*89c4ff92SAndroid Build Coastguard Worker                         layer_type: "StridedSliceLayer",
59*89c4ff92SAndroid Build Coastguard Worker                         layer: {
60*89c4ff92SAndroid Build Coastguard Worker                             base: {
61*89c4ff92SAndroid Build Coastguard Worker                                 index: 1,
62*89c4ff92SAndroid Build Coastguard Worker                                 layerName: "StridedSliceLayer",
63*89c4ff92SAndroid Build Coastguard Worker                                 layerType: "StridedSlice",
64*89c4ff92SAndroid Build Coastguard Worker                                 inputSlots: [{
65*89c4ff92SAndroid Build Coastguard Worker                                     index: 0,
66*89c4ff92SAndroid Build Coastguard Worker                                     connection: {sourceLayerIndex:0, outputSlotIndex:0 },
67*89c4ff92SAndroid Build Coastguard Worker                                 }],
68*89c4ff92SAndroid Build Coastguard Worker                                 outputSlots: [{
69*89c4ff92SAndroid Build Coastguard Worker                                     index: 0,
70*89c4ff92SAndroid Build Coastguard Worker                                     tensorInfo: {
71*89c4ff92SAndroid Build Coastguard Worker                                         dimensions: )" + outputShape + R"(,
72*89c4ff92SAndroid Build Coastguard Worker                                         dataType: )" + dataType + R"(
73*89c4ff92SAndroid Build Coastguard Worker                                     }
74*89c4ff92SAndroid Build Coastguard Worker                                 }]
75*89c4ff92SAndroid Build Coastguard Worker                             },
76*89c4ff92SAndroid Build Coastguard Worker                             descriptor: {
77*89c4ff92SAndroid Build Coastguard Worker                                 begin: )" + begin + R"(,
78*89c4ff92SAndroid Build Coastguard Worker                                 end: )" + end + R"(,
79*89c4ff92SAndroid Build Coastguard Worker                                 stride: )" + stride + R"(,
80*89c4ff92SAndroid Build Coastguard Worker                                 beginMask: )" + beginMask + R"(,
81*89c4ff92SAndroid Build Coastguard Worker                                 endMask: )" + endMask + R"(,
82*89c4ff92SAndroid Build Coastguard Worker                                 shrinkAxisMask: )" + shrinkAxisMask + R"(,
83*89c4ff92SAndroid Build Coastguard Worker                                 ellipsisMask: )" + ellipsisMask + R"(,
84*89c4ff92SAndroid Build Coastguard Worker                                 newAxisMask: )" + newAxisMask + R"(,
85*89c4ff92SAndroid Build Coastguard Worker                                 dataLayout: )" + dataLayout + R"(,
86*89c4ff92SAndroid Build Coastguard Worker                             }
87*89c4ff92SAndroid Build Coastguard Worker                         }
88*89c4ff92SAndroid Build Coastguard Worker                     },
89*89c4ff92SAndroid Build Coastguard Worker                     {
90*89c4ff92SAndroid Build Coastguard Worker                         layer_type: "OutputLayer",
91*89c4ff92SAndroid Build Coastguard Worker                         layer: {
92*89c4ff92SAndroid Build Coastguard Worker                             base:{
93*89c4ff92SAndroid Build Coastguard Worker                                 layerBindingId: 2,
94*89c4ff92SAndroid Build Coastguard Worker                                 base: {
95*89c4ff92SAndroid Build Coastguard Worker                                     index: 2,
96*89c4ff92SAndroid Build Coastguard Worker                                     layerName: "OutputLayer",
97*89c4ff92SAndroid Build Coastguard Worker                                     layerType: "Output",
98*89c4ff92SAndroid Build Coastguard Worker                                     inputSlots: [{
99*89c4ff92SAndroid Build Coastguard Worker                                         index: 0,
100*89c4ff92SAndroid Build Coastguard Worker                                         connection: {sourceLayerIndex:1, outputSlotIndex:0 },
101*89c4ff92SAndroid Build Coastguard Worker                                     }],
102*89c4ff92SAndroid Build Coastguard Worker                                     outputSlots: [{
103*89c4ff92SAndroid Build Coastguard Worker                                         index: 0,
104*89c4ff92SAndroid Build Coastguard Worker                                         tensorInfo: {
105*89c4ff92SAndroid Build Coastguard Worker                                             dimensions: )" + outputShape + R"(,
106*89c4ff92SAndroid Build Coastguard Worker                                             dataType: )" + dataType + R"(
107*89c4ff92SAndroid Build Coastguard Worker                                         },
108*89c4ff92SAndroid Build Coastguard Worker                                     }],
109*89c4ff92SAndroid Build Coastguard Worker                                 }
110*89c4ff92SAndroid Build Coastguard Worker                             }
111*89c4ff92SAndroid Build Coastguard Worker                         },
112*89c4ff92SAndroid Build Coastguard Worker                     }
113*89c4ff92SAndroid Build Coastguard Worker                 ]
114*89c4ff92SAndroid Build Coastguard Worker             }
115*89c4ff92SAndroid Build Coastguard Worker         )";
116*89c4ff92SAndroid Build Coastguard Worker         SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
117*89c4ff92SAndroid Build Coastguard Worker     }
118*89c4ff92SAndroid Build Coastguard Worker };
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker struct SimpleStridedSliceFixture : StridedSliceFixture
121*89c4ff92SAndroid Build Coastguard Worker {
SimpleStridedSliceFixtureSimpleStridedSliceFixture122*89c4ff92SAndroid Build Coastguard Worker     SimpleStridedSliceFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
123*89c4ff92SAndroid Build Coastguard Worker                                                       "[ 0, 0, 0, 0 ]",
124*89c4ff92SAndroid Build Coastguard Worker                                                       "[ 3, 2, 3, 1 ]",
125*89c4ff92SAndroid Build Coastguard Worker                                                       "[ 2, 2, 2, 1 ]",
126*89c4ff92SAndroid Build Coastguard Worker                                                       "0",
127*89c4ff92SAndroid Build Coastguard Worker                                                       "0",
128*89c4ff92SAndroid Build Coastguard Worker                                                       "0",
129*89c4ff92SAndroid Build Coastguard Worker                                                       "0",
130*89c4ff92SAndroid Build Coastguard Worker                                                       "0",
131*89c4ff92SAndroid Build Coastguard Worker                                                       "NCHW",
132*89c4ff92SAndroid Build Coastguard Worker                                                       "[ 2, 1, 2, 1 ]",
133*89c4ff92SAndroid Build Coastguard Worker                                                       "Float32") {}
134*89c4ff92SAndroid Build Coastguard Worker };
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleStridedSliceFixture, "SimpleStridedSliceFloat32")
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(0,
139*89c4ff92SAndroid Build Coastguard Worker                                          {
140*89c4ff92SAndroid Build Coastguard Worker                                              1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
141*89c4ff92SAndroid Build Coastguard Worker                                              3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
142*89c4ff92SAndroid Build Coastguard Worker                                              5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
143*89c4ff92SAndroid Build Coastguard Worker                                          },
144*89c4ff92SAndroid Build Coastguard Worker                                          {
145*89c4ff92SAndroid Build Coastguard Worker                                              1.0f, 1.0f, 5.0f, 5.0f
146*89c4ff92SAndroid Build Coastguard Worker                                          });
147*89c4ff92SAndroid Build Coastguard Worker }
148*89c4ff92SAndroid Build Coastguard Worker 
149*89c4ff92SAndroid Build Coastguard Worker struct StridedSliceMaskFixture : StridedSliceFixture
150*89c4ff92SAndroid Build Coastguard Worker {
StridedSliceMaskFixtureStridedSliceMaskFixture151*89c4ff92SAndroid Build Coastguard Worker     StridedSliceMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
152*89c4ff92SAndroid Build Coastguard Worker                                                     "[ 1, 1, 1, 1 ]",
153*89c4ff92SAndroid Build Coastguard Worker                                                     "[ 1, 1, 1, 1 ]",
154*89c4ff92SAndroid Build Coastguard Worker                                                     "[ 1, 1, 1, 1 ]",
155*89c4ff92SAndroid Build Coastguard Worker                                                     "15",
156*89c4ff92SAndroid Build Coastguard Worker                                                     "15",
157*89c4ff92SAndroid Build Coastguard Worker                                                     "0",
158*89c4ff92SAndroid Build Coastguard Worker                                                     "0",
159*89c4ff92SAndroid Build Coastguard Worker                                                     "0",
160*89c4ff92SAndroid Build Coastguard Worker                                                     "NCHW",
161*89c4ff92SAndroid Build Coastguard Worker                                                     "[ 3, 2, 3, 1 ]",
162*89c4ff92SAndroid Build Coastguard Worker                                                     "Float32") {}
163*89c4ff92SAndroid Build Coastguard Worker };
164*89c4ff92SAndroid Build Coastguard Worker 
165*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(StridedSliceMaskFixture, "StridedSliceMaskFloat32")
166*89c4ff92SAndroid Build Coastguard Worker {
167*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(0,
168*89c4ff92SAndroid Build Coastguard Worker                                          {
169*89c4ff92SAndroid Build Coastguard Worker                                              1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
170*89c4ff92SAndroid Build Coastguard Worker                                              3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
171*89c4ff92SAndroid Build Coastguard Worker                                              5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
172*89c4ff92SAndroid Build Coastguard Worker                                          },
173*89c4ff92SAndroid Build Coastguard Worker                                          {
174*89c4ff92SAndroid Build Coastguard Worker                                              1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
175*89c4ff92SAndroid Build Coastguard Worker                                              3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
176*89c4ff92SAndroid Build Coastguard Worker                                              5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
177*89c4ff92SAndroid Build Coastguard Worker                                          });
178*89c4ff92SAndroid Build Coastguard Worker }
179*89c4ff92SAndroid Build Coastguard Worker 
180*89c4ff92SAndroid Build Coastguard Worker }
181