xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/SpaceToDepthEndToEndTestImpl.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "SpaceToDepthEndToEndTestImpl.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "ResolveType.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "EndToEndTestImpl.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/DataLayoutIndexed.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/DataLayoutUtils.hpp>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker #include <TestUtils.hpp>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker namespace
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker template<typename armnn::DataType DataType>
CreateSpaceToDepthNetwork(const armnn::TensorShape & inputShape,const armnn::TensorShape & outputShape,const armnn::DataLayout dataLayout,unsigned int blockSize,const float qScale=1.0f,const int32_t qOffset=0)25*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateSpaceToDepthNetwork(const armnn::TensorShape& inputShape,
26*89c4ff92SAndroid Build Coastguard Worker                                              const armnn::TensorShape& outputShape,
27*89c4ff92SAndroid Build Coastguard Worker                                              const armnn::DataLayout dataLayout,
28*89c4ff92SAndroid Build Coastguard Worker                                              unsigned int blockSize,
29*89c4ff92SAndroid Build Coastguard Worker                                              const float qScale = 1.0f,
30*89c4ff92SAndroid Build Coastguard Worker                                              const int32_t qOffset = 0)
31*89c4ff92SAndroid Build Coastguard Worker {
32*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
35*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo(inputShape, DataType, qScale, qOffset, true);
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker     armnnUtils::DataLayoutIndexed dimensionIndices(dataLayout);
40*89c4ff92SAndroid Build Coastguard Worker     if (inputShape[dimensionIndices.GetHeightIndex()] % blockSize!=0
41*89c4ff92SAndroid Build Coastguard Worker         || inputShape[dimensionIndices.GetWidthIndex()] % blockSize!=0)
42*89c4ff92SAndroid Build Coastguard Worker     {
43*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException("Input shape must be divisible by block size in all spatial dimensions");
44*89c4ff92SAndroid Build Coastguard Worker     }
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker     SpaceToDepthDescriptor spaceToDepthDesc;
47*89c4ff92SAndroid Build Coastguard Worker     spaceToDepthDesc.m_BlockSize = blockSize;
48*89c4ff92SAndroid Build Coastguard Worker     spaceToDepthDesc.m_DataLayout = dataLayout;
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* SpaceToDepth = net->AddSpaceToDepthLayer(spaceToDepthDesc, "SpaceToDepth");
51*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input        = net->AddInputLayer(0, "input");
52*89c4ff92SAndroid Build Coastguard Worker     Connect(input, SpaceToDepth, inputTensorInfo, 0, 0);
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo(outputShape, DataType, qScale, qOffset);
55*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
56*89c4ff92SAndroid Build Coastguard Worker     Connect(SpaceToDepth, output, outputTensorInfo, 0, 0);
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker     return net;
59*89c4ff92SAndroid Build Coastguard Worker }
60*89c4ff92SAndroid Build Coastguard Worker 
SpaceToDepthEndToEnd(const std::vector<armnn::BackendId> & backends,const armnn::DataLayout & dataLayout,armnn::TensorInfo & inputTensorInfo,armnn::TensorInfo & outputTensorInfo,std::vector<float> & inputData,std::vector<float> & expectedOutputData,const unsigned int blockSize)61*89c4ff92SAndroid Build Coastguard Worker void SpaceToDepthEndToEnd(const std::vector<armnn::BackendId>& backends,
62*89c4ff92SAndroid Build Coastguard Worker                           const armnn::DataLayout& dataLayout,
63*89c4ff92SAndroid Build Coastguard Worker                           armnn::TensorInfo& inputTensorInfo,
64*89c4ff92SAndroid Build Coastguard Worker                           armnn::TensorInfo& outputTensorInfo,
65*89c4ff92SAndroid Build Coastguard Worker                           std::vector<float>& inputData,
66*89c4ff92SAndroid Build Coastguard Worker                           std::vector<float>& expectedOutputData,
67*89c4ff92SAndroid Build Coastguard Worker                           const unsigned int blockSize)
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker     if (dataLayout == DataLayout::NCHW)
72*89c4ff92SAndroid Build Coastguard Worker     {
73*89c4ff92SAndroid Build Coastguard Worker         PermuteTensorNhwcToNchw<float>(inputTensorInfo, inputData);
74*89c4ff92SAndroid Build Coastguard Worker         PermuteTensorNhwcToNchw<float>(outputTensorInfo, expectedOutputData);
75*89c4ff92SAndroid Build Coastguard Worker     }
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network
78*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net = CreateSpaceToDepthNetwork<DataType::Float32>(
79*89c4ff92SAndroid Build Coastguard Worker             inputTensorInfo.GetShape(),
80*89c4ff92SAndroid Build Coastguard Worker             outputTensorInfo.GetShape(),
81*89c4ff92SAndroid Build Coastguard Worker             dataLayout,
82*89c4ff92SAndroid Build Coastguard Worker             blockSize);
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker     CHECK(net);
85*89c4ff92SAndroid Build Coastguard Worker 
86*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<float>> inputTensorData = { { 0, inputData } };
87*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<float>> expectedOutputTensorData = { { 0, expectedOutputData } };
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<DataType::Float32, DataType::Float32>(
90*89c4ff92SAndroid Build Coastguard Worker             move(net),
91*89c4ff92SAndroid Build Coastguard Worker             inputTensorData,
92*89c4ff92SAndroid Build Coastguard Worker             expectedOutputTensorData,
93*89c4ff92SAndroid Build Coastguard Worker             backends);
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
97*89c4ff92SAndroid Build Coastguard Worker 
SpaceToDepthNhwcEndToEndTest1(const std::vector<armnn::BackendId> & defaultBackends)98*89c4ff92SAndroid Build Coastguard Worker void SpaceToDepthNhwcEndToEndTest1(const std::vector<armnn::BackendId>& defaultBackends)
99*89c4ff92SAndroid Build Coastguard Worker {
100*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker     const unsigned int blockSize = 2;
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker     TensorShape inputShape{1, 2, 2, 1};
105*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo(inputShape, DataType::Float32, 0.0f, 0, true);
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker     TensorShape outputShape{1, 1, 1, 4};
108*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo(outputShape, DataType::Float32);
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData = std::vector<float>(
111*89c4ff92SAndroid Build Coastguard Worker     {
112*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
113*89c4ff92SAndroid Build Coastguard Worker     });
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputData = std::vector<float>(
116*89c4ff92SAndroid Build Coastguard Worker     {
117*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
118*89c4ff92SAndroid Build Coastguard Worker     });
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker     SpaceToDepthEndToEnd(defaultBackends,
121*89c4ff92SAndroid Build Coastguard Worker                          DataLayout::NHWC,
122*89c4ff92SAndroid Build Coastguard Worker                          inputTensorInfo,
123*89c4ff92SAndroid Build Coastguard Worker                          outputTensorInfo,
124*89c4ff92SAndroid Build Coastguard Worker                          inputData,
125*89c4ff92SAndroid Build Coastguard Worker                          expectedOutputData,
126*89c4ff92SAndroid Build Coastguard Worker                          blockSize);
127*89c4ff92SAndroid Build Coastguard Worker }
128*89c4ff92SAndroid Build Coastguard Worker 
SpaceToDepthNchwEndToEndTest1(const std::vector<armnn::BackendId> & defaultBackends)129*89c4ff92SAndroid Build Coastguard Worker void SpaceToDepthNchwEndToEndTest1(const std::vector<armnn::BackendId>& defaultBackends)
130*89c4ff92SAndroid Build Coastguard Worker {
131*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker     const unsigned int blockSize = 2;
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker     TensorShape inputShape{1, 2, 2, 1};
136*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo(inputShape, DataType::Float32, 0.0f, 0, true);
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker     TensorShape outputShape{1, 1, 1, 4};
139*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo(outputShape, DataType::Float32);
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData = std::vector<float>(
142*89c4ff92SAndroid Build Coastguard Worker     {
143*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
144*89c4ff92SAndroid Build Coastguard Worker     });
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputData = std::vector<float>(
147*89c4ff92SAndroid Build Coastguard Worker     {
148*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f
149*89c4ff92SAndroid Build Coastguard Worker     });
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker     SpaceToDepthEndToEnd(defaultBackends,
152*89c4ff92SAndroid Build Coastguard Worker                          DataLayout::NCHW,
153*89c4ff92SAndroid Build Coastguard Worker                          inputTensorInfo,
154*89c4ff92SAndroid Build Coastguard Worker                          outputTensorInfo,
155*89c4ff92SAndroid Build Coastguard Worker                          inputData,
156*89c4ff92SAndroid Build Coastguard Worker                          expectedOutputData,
157*89c4ff92SAndroid Build Coastguard Worker                          blockSize);
158*89c4ff92SAndroid Build Coastguard Worker }
159*89c4ff92SAndroid Build Coastguard Worker 
SpaceToDepthNhwcEndToEndTest2(const std::vector<armnn::BackendId> & defaultBackends)160*89c4ff92SAndroid Build Coastguard Worker void SpaceToDepthNhwcEndToEndTest2(const std::vector<armnn::BackendId>& defaultBackends)
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
163*89c4ff92SAndroid Build Coastguard Worker 
164*89c4ff92SAndroid Build Coastguard Worker     const unsigned int blockSize = 2;
165*89c4ff92SAndroid Build Coastguard Worker 
166*89c4ff92SAndroid Build Coastguard Worker     TensorShape inputShape{1, 2, 2, 2};
167*89c4ff92SAndroid Build Coastguard Worker     TensorShape outputShape{1, 1, 1, 8};
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo(outputShape, DataType::Float32);
170*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo(inputShape, DataType::Float32, 0.0f, 0, true);
171*89c4ff92SAndroid Build Coastguard Worker 
172*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData = std::vector<float>(
173*89c4ff92SAndroid Build Coastguard Worker     {
174*89c4ff92SAndroid Build Coastguard Worker         1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f
175*89c4ff92SAndroid Build Coastguard Worker     });
176*89c4ff92SAndroid Build Coastguard Worker 
177*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputData = std::vector<float>(
178*89c4ff92SAndroid Build Coastguard Worker     {
179*89c4ff92SAndroid Build Coastguard Worker         1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f
180*89c4ff92SAndroid Build Coastguard Worker     });
181*89c4ff92SAndroid Build Coastguard Worker 
182*89c4ff92SAndroid Build Coastguard Worker     SpaceToDepthEndToEnd(defaultBackends,
183*89c4ff92SAndroid Build Coastguard Worker                          DataLayout::NHWC,
184*89c4ff92SAndroid Build Coastguard Worker                          inputTensorInfo,
185*89c4ff92SAndroid Build Coastguard Worker                          outputTensorInfo,
186*89c4ff92SAndroid Build Coastguard Worker                          inputData,
187*89c4ff92SAndroid Build Coastguard Worker                          expectedOutputData,
188*89c4ff92SAndroid Build Coastguard Worker                          blockSize);
189*89c4ff92SAndroid Build Coastguard Worker }
190*89c4ff92SAndroid Build Coastguard Worker 
SpaceToDepthNchwEndToEndTest2(const std::vector<armnn::BackendId> & defaultBackends)191*89c4ff92SAndroid Build Coastguard Worker void SpaceToDepthNchwEndToEndTest2(const std::vector<armnn::BackendId>& defaultBackends)
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
194*89c4ff92SAndroid Build Coastguard Worker 
195*89c4ff92SAndroid Build Coastguard Worker     const unsigned int blockSize = 2;
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker     TensorShape inputShape{1, 2, 2, 2};
198*89c4ff92SAndroid Build Coastguard Worker     TensorShape outputShape{1, 1, 1, 8};
199*89c4ff92SAndroid Build Coastguard Worker 
200*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo(inputShape, DataType::Float32, 0.0f, 0, true);
201*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo(outputShape, DataType::Float32);
202*89c4ff92SAndroid Build Coastguard Worker 
203*89c4ff92SAndroid Build Coastguard Worker 
204*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData = std::vector<float>(
205*89c4ff92SAndroid Build Coastguard Worker     {
206*89c4ff92SAndroid Build Coastguard Worker         1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f
207*89c4ff92SAndroid Build Coastguard Worker     });
208*89c4ff92SAndroid Build Coastguard Worker 
209*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputData = std::vector<float>(
210*89c4ff92SAndroid Build Coastguard Worker     {
211*89c4ff92SAndroid Build Coastguard Worker         1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f
212*89c4ff92SAndroid Build Coastguard Worker     });
213*89c4ff92SAndroid Build Coastguard Worker 
214*89c4ff92SAndroid Build Coastguard Worker     SpaceToDepthEndToEnd(defaultBackends,
215*89c4ff92SAndroid Build Coastguard Worker                          DataLayout::NCHW,
216*89c4ff92SAndroid Build Coastguard Worker                          inputTensorInfo,
217*89c4ff92SAndroid Build Coastguard Worker                          outputTensorInfo,
218*89c4ff92SAndroid Build Coastguard Worker                          inputData,
219*89c4ff92SAndroid Build Coastguard Worker                          expectedOutputData,
220*89c4ff92SAndroid Build Coastguard Worker                          blockSize);
221*89c4ff92SAndroid Build Coastguard Worker }
222