xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeStridedSlice.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include <armnnDeserializer/IDeserializer.hpp>
8 
9 #include <string>
10 
11 TEST_SUITE("Deserializer_StridedSlice")
12 {
13 struct StridedSliceFixture : public ParserFlatbuffersSerializeFixture
14 {
StridedSliceFixtureStridedSliceFixture15     explicit StridedSliceFixture(const std::string& inputShape,
16                                  const std::string& begin,
17                                  const std::string& end,
18                                  const std::string& stride,
19                                  const std::string& beginMask,
20                                  const std::string& endMask,
21                                  const std::string& shrinkAxisMask,
22                                  const std::string& ellipsisMask,
23                                  const std::string& newAxisMask,
24                                  const std::string& dataLayout,
25                                  const std::string& outputShape,
26                                  const std::string& dataType)
27     {
28         m_JsonString = R"(
29             {
30                 inputIds: [0],
31                 outputIds: [2],
32                 layers: [
33                     {
34                         layer_type: "InputLayer",
35                         layer: {
36                             base: {
37                                 layerBindingId: 0,
38                                 base: {
39                                     index: 0,
40                                     layerName: "InputLayer",
41                                     layerType: "Input",
42                                     inputSlots: [{
43                                         index: 0,
44                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
45                                     }],
46                                     outputSlots: [{
47                                         index: 0,
48                                         tensorInfo: {
49                                             dimensions: )" + inputShape + R"(,
50                                             dataType: )" + dataType + R"(
51                                         }
52                                     }]
53                                 }
54                             }
55                         }
56                     },
57                     {
58                         layer_type: "StridedSliceLayer",
59                         layer: {
60                             base: {
61                                 index: 1,
62                                 layerName: "StridedSliceLayer",
63                                 layerType: "StridedSlice",
64                                 inputSlots: [{
65                                     index: 0,
66                                     connection: {sourceLayerIndex:0, outputSlotIndex:0 },
67                                 }],
68                                 outputSlots: [{
69                                     index: 0,
70                                     tensorInfo: {
71                                         dimensions: )" + outputShape + R"(,
72                                         dataType: )" + dataType + R"(
73                                     }
74                                 }]
75                             },
76                             descriptor: {
77                                 begin: )" + begin + R"(,
78                                 end: )" + end + R"(,
79                                 stride: )" + stride + R"(,
80                                 beginMask: )" + beginMask + R"(,
81                                 endMask: )" + endMask + R"(,
82                                 shrinkAxisMask: )" + shrinkAxisMask + R"(,
83                                 ellipsisMask: )" + ellipsisMask + R"(,
84                                 newAxisMask: )" + newAxisMask + R"(,
85                                 dataLayout: )" + dataLayout + R"(,
86                             }
87                         }
88                     },
89                     {
90                         layer_type: "OutputLayer",
91                         layer: {
92                             base:{
93                                 layerBindingId: 2,
94                                 base: {
95                                     index: 2,
96                                     layerName: "OutputLayer",
97                                     layerType: "Output",
98                                     inputSlots: [{
99                                         index: 0,
100                                         connection: {sourceLayerIndex:1, outputSlotIndex:0 },
101                                     }],
102                                     outputSlots: [{
103                                         index: 0,
104                                         tensorInfo: {
105                                             dimensions: )" + outputShape + R"(,
106                                             dataType: )" + dataType + R"(
107                                         },
108                                     }],
109                                 }
110                             }
111                         },
112                     }
113                 ]
114             }
115         )";
116         SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
117     }
118 };
119 
120 struct SimpleStridedSliceFixture : StridedSliceFixture
121 {
SimpleStridedSliceFixtureSimpleStridedSliceFixture122     SimpleStridedSliceFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
123                                                       "[ 0, 0, 0, 0 ]",
124                                                       "[ 3, 2, 3, 1 ]",
125                                                       "[ 2, 2, 2, 1 ]",
126                                                       "0",
127                                                       "0",
128                                                       "0",
129                                                       "0",
130                                                       "0",
131                                                       "NCHW",
132                                                       "[ 2, 1, 2, 1 ]",
133                                                       "Float32") {}
134 };
135 
136 TEST_CASE_FIXTURE(SimpleStridedSliceFixture, "SimpleStridedSliceFloat32")
137 {
138     RunTest<4, armnn::DataType::Float32>(0,
139                                          {
140                                              1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
141                                              3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
142                                              5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
143                                          },
144                                          {
145                                              1.0f, 1.0f, 5.0f, 5.0f
146                                          });
147 }
148 
149 struct StridedSliceMaskFixture : StridedSliceFixture
150 {
StridedSliceMaskFixtureStridedSliceMaskFixture151     StridedSliceMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
152                                                     "[ 1, 1, 1, 1 ]",
153                                                     "[ 1, 1, 1, 1 ]",
154                                                     "[ 1, 1, 1, 1 ]",
155                                                     "15",
156                                                     "15",
157                                                     "0",
158                                                     "0",
159                                                     "0",
160                                                     "NCHW",
161                                                     "[ 3, 2, 3, 1 ]",
162                                                     "Float32") {}
163 };
164 
165 TEST_CASE_FIXTURE(StridedSliceMaskFixture, "StridedSliceMaskFloat32")
166 {
167     RunTest<4, armnn::DataType::Float32>(0,
168                                          {
169                                              1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
170                                              3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
171                                              5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
172                                          },
173                                          {
174                                              1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
175                                              3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
176                                              5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
177                                          });
178 }
179 
180 }
181