1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "SpaceToDepthEndToEndTestImpl.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "ResolveType.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "EndToEndTestImpl.hpp"
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/DataLayoutIndexed.hpp>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/DataLayoutUtils.hpp>
16*89c4ff92SAndroid Build Coastguard Worker
17*89c4ff92SAndroid Build Coastguard Worker #include <TestUtils.hpp>
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
20*89c4ff92SAndroid Build Coastguard Worker
21*89c4ff92SAndroid Build Coastguard Worker namespace
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker
24*89c4ff92SAndroid Build Coastguard Worker template<typename armnn::DataType DataType>
CreateSpaceToDepthNetwork(const armnn::TensorShape & inputShape,const armnn::TensorShape & outputShape,const armnn::DataLayout dataLayout,unsigned int blockSize,const float qScale=1.0f,const int32_t qOffset=0)25*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateSpaceToDepthNetwork(const armnn::TensorShape& inputShape,
26*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& outputShape,
27*89c4ff92SAndroid Build Coastguard Worker const armnn::DataLayout dataLayout,
28*89c4ff92SAndroid Build Coastguard Worker unsigned int blockSize,
29*89c4ff92SAndroid Build Coastguard Worker const float qScale = 1.0f,
30*89c4ff92SAndroid Build Coastguard Worker const int32_t qOffset = 0)
31*89c4ff92SAndroid Build Coastguard Worker {
32*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
33*89c4ff92SAndroid Build Coastguard Worker
34*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network.
35*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create());
36*89c4ff92SAndroid Build Coastguard Worker
37*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo(inputShape, DataType, qScale, qOffset, true);
38*89c4ff92SAndroid Build Coastguard Worker
39*89c4ff92SAndroid Build Coastguard Worker armnnUtils::DataLayoutIndexed dimensionIndices(dataLayout);
40*89c4ff92SAndroid Build Coastguard Worker if (inputShape[dimensionIndices.GetHeightIndex()] % blockSize!=0
41*89c4ff92SAndroid Build Coastguard Worker || inputShape[dimensionIndices.GetWidthIndex()] % blockSize!=0)
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Input shape must be divisible by block size in all spatial dimensions");
44*89c4ff92SAndroid Build Coastguard Worker }
45*89c4ff92SAndroid Build Coastguard Worker
46*89c4ff92SAndroid Build Coastguard Worker SpaceToDepthDescriptor spaceToDepthDesc;
47*89c4ff92SAndroid Build Coastguard Worker spaceToDepthDesc.m_BlockSize = blockSize;
48*89c4ff92SAndroid Build Coastguard Worker spaceToDepthDesc.m_DataLayout = dataLayout;
49*89c4ff92SAndroid Build Coastguard Worker
50*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* SpaceToDepth = net->AddSpaceToDepthLayer(spaceToDepthDesc, "SpaceToDepth");
51*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input = net->AddInputLayer(0, "input");
52*89c4ff92SAndroid Build Coastguard Worker Connect(input, SpaceToDepth, inputTensorInfo, 0, 0);
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo(outputShape, DataType, qScale, qOffset);
55*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output");
56*89c4ff92SAndroid Build Coastguard Worker Connect(SpaceToDepth, output, outputTensorInfo, 0, 0);
57*89c4ff92SAndroid Build Coastguard Worker
58*89c4ff92SAndroid Build Coastguard Worker return net;
59*89c4ff92SAndroid Build Coastguard Worker }
60*89c4ff92SAndroid Build Coastguard Worker
SpaceToDepthEndToEnd(const std::vector<armnn::BackendId> & backends,const armnn::DataLayout & dataLayout,armnn::TensorInfo & inputTensorInfo,armnn::TensorInfo & outputTensorInfo,std::vector<float> & inputData,std::vector<float> & expectedOutputData,const unsigned int blockSize)61*89c4ff92SAndroid Build Coastguard Worker void SpaceToDepthEndToEnd(const std::vector<armnn::BackendId>& backends,
62*89c4ff92SAndroid Build Coastguard Worker const armnn::DataLayout& dataLayout,
63*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo& inputTensorInfo,
64*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo& outputTensorInfo,
65*89c4ff92SAndroid Build Coastguard Worker std::vector<float>& inputData,
66*89c4ff92SAndroid Build Coastguard Worker std::vector<float>& expectedOutputData,
67*89c4ff92SAndroid Build Coastguard Worker const unsigned int blockSize)
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker if (dataLayout == DataLayout::NCHW)
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker PermuteTensorNhwcToNchw<float>(inputTensorInfo, inputData);
74*89c4ff92SAndroid Build Coastguard Worker PermuteTensorNhwcToNchw<float>(outputTensorInfo, expectedOutputData);
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker
77*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
78*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateSpaceToDepthNetwork<DataType::Float32>(
79*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.GetShape(),
80*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo.GetShape(),
81*89c4ff92SAndroid Build Coastguard Worker dataLayout,
82*89c4ff92SAndroid Build Coastguard Worker blockSize);
83*89c4ff92SAndroid Build Coastguard Worker
84*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
85*89c4ff92SAndroid Build Coastguard Worker
86*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<float>> inputTensorData = { { 0, inputData } };
87*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<float>> expectedOutputTensorData = { { 0, expectedOutputData } };
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<DataType::Float32, DataType::Float32>(
90*89c4ff92SAndroid Build Coastguard Worker move(net),
91*89c4ff92SAndroid Build Coastguard Worker inputTensorData,
92*89c4ff92SAndroid Build Coastguard Worker expectedOutputTensorData,
93*89c4ff92SAndroid Build Coastguard Worker backends);
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker
96*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
97*89c4ff92SAndroid Build Coastguard Worker
SpaceToDepthNhwcEndToEndTest1(const std::vector<armnn::BackendId> & defaultBackends)98*89c4ff92SAndroid Build Coastguard Worker void SpaceToDepthNhwcEndToEndTest1(const std::vector<armnn::BackendId>& defaultBackends)
99*89c4ff92SAndroid Build Coastguard Worker {
100*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
101*89c4ff92SAndroid Build Coastguard Worker
102*89c4ff92SAndroid Build Coastguard Worker const unsigned int blockSize = 2;
103*89c4ff92SAndroid Build Coastguard Worker
104*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape{1, 2, 2, 1};
105*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo(inputShape, DataType::Float32, 0.0f, 0, true);
106*89c4ff92SAndroid Build Coastguard Worker
107*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape{1, 1, 1, 4};
108*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo(outputShape, DataType::Float32);
109*89c4ff92SAndroid Build Coastguard Worker
110*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData = std::vector<float>(
111*89c4ff92SAndroid Build Coastguard Worker {
112*89c4ff92SAndroid Build Coastguard Worker 1.0f, 2.0f, 3.0f, 4.0f
113*89c4ff92SAndroid Build Coastguard Worker });
114*89c4ff92SAndroid Build Coastguard Worker
115*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputData = std::vector<float>(
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker 1.0f, 2.0f, 3.0f, 4.0f
118*89c4ff92SAndroid Build Coastguard Worker });
119*89c4ff92SAndroid Build Coastguard Worker
120*89c4ff92SAndroid Build Coastguard Worker SpaceToDepthEndToEnd(defaultBackends,
121*89c4ff92SAndroid Build Coastguard Worker DataLayout::NHWC,
122*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
123*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
124*89c4ff92SAndroid Build Coastguard Worker inputData,
125*89c4ff92SAndroid Build Coastguard Worker expectedOutputData,
126*89c4ff92SAndroid Build Coastguard Worker blockSize);
127*89c4ff92SAndroid Build Coastguard Worker }
128*89c4ff92SAndroid Build Coastguard Worker
SpaceToDepthNchwEndToEndTest1(const std::vector<armnn::BackendId> & defaultBackends)129*89c4ff92SAndroid Build Coastguard Worker void SpaceToDepthNchwEndToEndTest1(const std::vector<armnn::BackendId>& defaultBackends)
130*89c4ff92SAndroid Build Coastguard Worker {
131*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
132*89c4ff92SAndroid Build Coastguard Worker
133*89c4ff92SAndroid Build Coastguard Worker const unsigned int blockSize = 2;
134*89c4ff92SAndroid Build Coastguard Worker
135*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape{1, 2, 2, 1};
136*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo(inputShape, DataType::Float32, 0.0f, 0, true);
137*89c4ff92SAndroid Build Coastguard Worker
138*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape{1, 1, 1, 4};
139*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo(outputShape, DataType::Float32);
140*89c4ff92SAndroid Build Coastguard Worker
141*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData = std::vector<float>(
142*89c4ff92SAndroid Build Coastguard Worker {
143*89c4ff92SAndroid Build Coastguard Worker 1.0f, 2.0f, 3.0f, 4.0f
144*89c4ff92SAndroid Build Coastguard Worker });
145*89c4ff92SAndroid Build Coastguard Worker
146*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputData = std::vector<float>(
147*89c4ff92SAndroid Build Coastguard Worker {
148*89c4ff92SAndroid Build Coastguard Worker 1.0f, 2.0f, 3.0f, 4.0f
149*89c4ff92SAndroid Build Coastguard Worker });
150*89c4ff92SAndroid Build Coastguard Worker
151*89c4ff92SAndroid Build Coastguard Worker SpaceToDepthEndToEnd(defaultBackends,
152*89c4ff92SAndroid Build Coastguard Worker DataLayout::NCHW,
153*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
154*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
155*89c4ff92SAndroid Build Coastguard Worker inputData,
156*89c4ff92SAndroid Build Coastguard Worker expectedOutputData,
157*89c4ff92SAndroid Build Coastguard Worker blockSize);
158*89c4ff92SAndroid Build Coastguard Worker }
159*89c4ff92SAndroid Build Coastguard Worker
SpaceToDepthNhwcEndToEndTest2(const std::vector<armnn::BackendId> & defaultBackends)160*89c4ff92SAndroid Build Coastguard Worker void SpaceToDepthNhwcEndToEndTest2(const std::vector<armnn::BackendId>& defaultBackends)
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
163*89c4ff92SAndroid Build Coastguard Worker
164*89c4ff92SAndroid Build Coastguard Worker const unsigned int blockSize = 2;
165*89c4ff92SAndroid Build Coastguard Worker
166*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape{1, 2, 2, 2};
167*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape{1, 1, 1, 8};
168*89c4ff92SAndroid Build Coastguard Worker
169*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo(outputShape, DataType::Float32);
170*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo(inputShape, DataType::Float32, 0.0f, 0, true);
171*89c4ff92SAndroid Build Coastguard Worker
172*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData = std::vector<float>(
173*89c4ff92SAndroid Build Coastguard Worker {
174*89c4ff92SAndroid Build Coastguard Worker 1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f
175*89c4ff92SAndroid Build Coastguard Worker });
176*89c4ff92SAndroid Build Coastguard Worker
177*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputData = std::vector<float>(
178*89c4ff92SAndroid Build Coastguard Worker {
179*89c4ff92SAndroid Build Coastguard Worker 1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f
180*89c4ff92SAndroid Build Coastguard Worker });
181*89c4ff92SAndroid Build Coastguard Worker
182*89c4ff92SAndroid Build Coastguard Worker SpaceToDepthEndToEnd(defaultBackends,
183*89c4ff92SAndroid Build Coastguard Worker DataLayout::NHWC,
184*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
185*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
186*89c4ff92SAndroid Build Coastguard Worker inputData,
187*89c4ff92SAndroid Build Coastguard Worker expectedOutputData,
188*89c4ff92SAndroid Build Coastguard Worker blockSize);
189*89c4ff92SAndroid Build Coastguard Worker }
190*89c4ff92SAndroid Build Coastguard Worker
SpaceToDepthNchwEndToEndTest2(const std::vector<armnn::BackendId> & defaultBackends)191*89c4ff92SAndroid Build Coastguard Worker void SpaceToDepthNchwEndToEndTest2(const std::vector<armnn::BackendId>& defaultBackends)
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
194*89c4ff92SAndroid Build Coastguard Worker
195*89c4ff92SAndroid Build Coastguard Worker const unsigned int blockSize = 2;
196*89c4ff92SAndroid Build Coastguard Worker
197*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape{1, 2, 2, 2};
198*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape{1, 1, 1, 8};
199*89c4ff92SAndroid Build Coastguard Worker
200*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo(inputShape, DataType::Float32, 0.0f, 0, true);
201*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo(outputShape, DataType::Float32);
202*89c4ff92SAndroid Build Coastguard Worker
203*89c4ff92SAndroid Build Coastguard Worker
204*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData = std::vector<float>(
205*89c4ff92SAndroid Build Coastguard Worker {
206*89c4ff92SAndroid Build Coastguard Worker 1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f
207*89c4ff92SAndroid Build Coastguard Worker });
208*89c4ff92SAndroid Build Coastguard Worker
209*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputData = std::vector<float>(
210*89c4ff92SAndroid Build Coastguard Worker {
211*89c4ff92SAndroid Build Coastguard Worker 1.4f, 2.3f, 3.2f, 4.1f, 5.4f, 6.3f, 7.2f, 8.1f
212*89c4ff92SAndroid Build Coastguard Worker });
213*89c4ff92SAndroid Build Coastguard Worker
214*89c4ff92SAndroid Build Coastguard Worker SpaceToDepthEndToEnd(defaultBackends,
215*89c4ff92SAndroid Build Coastguard Worker DataLayout::NCHW,
216*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
217*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
218*89c4ff92SAndroid Build Coastguard Worker inputData,
219*89c4ff92SAndroid Build Coastguard Worker expectedOutputData,
220*89c4ff92SAndroid Build Coastguard Worker blockSize);
221*89c4ff92SAndroid Build Coastguard Worker }
222