// xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/BatchToSpaceNdEndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <ResolveType.hpp>

#include <armnn/INetwork.hpp>

#include <CommonTestUtils.hpp>

#include <doctest/doctest.h>

#include <map>
#include <utility>
#include <vector>

17*89c4ff92SAndroid Build Coastguard Worker namespace
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker template<typename armnn::DataType DataType>
CreateBatchToSpaceNdNetwork(const armnn::TensorShape & inputShape,const armnn::TensorShape & outputShape,std::vector<unsigned int> & blockShape,std::vector<std::pair<unsigned int,unsigned int>> & crops,armnn::DataLayout dataLayout,const float qScale=1.0f,const int32_t qOffset=0)21*89c4ff92SAndroid Build Coastguard Worker INetworkPtr CreateBatchToSpaceNdNetwork(const armnn::TensorShape& inputShape,
22*89c4ff92SAndroid Build Coastguard Worker                                         const armnn::TensorShape& outputShape,
23*89c4ff92SAndroid Build Coastguard Worker                                         std::vector<unsigned int>& blockShape,
24*89c4ff92SAndroid Build Coastguard Worker                                         std::vector<std::pair<unsigned int, unsigned int>>& crops,
25*89c4ff92SAndroid Build Coastguard Worker                                         armnn::DataLayout dataLayout,
26*89c4ff92SAndroid Build Coastguard Worker                                         const float qScale = 1.0f,
27*89c4ff92SAndroid Build Coastguard Worker                                         const int32_t qOffset = 0)
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
30*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
31*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo(inputShape, DataType, qScale, qOffset, true);
34*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo(outputShape, DataType, qScale, qOffset);
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker     BatchToSpaceNdDescriptor batchToSpaceNdDesc(blockShape, crops);
37*89c4ff92SAndroid Build Coastguard Worker     batchToSpaceNdDesc.m_DataLayout = dataLayout;
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* batchToSpaceNd = net->AddBatchToSpaceNdLayer(batchToSpaceNdDesc, "batchToSpaceNd");
40*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0, "input");
41*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker     Connect(batchToSpaceNd, output, outputTensorInfo, 0, 0);
44*89c4ff92SAndroid Build Coastguard Worker     Connect(input, batchToSpaceNd, inputTensorInfo, 0, 0);
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker     return net;
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
BatchToSpaceNdEndToEnd(const std::vector<BackendId> & backends,armnn::DataLayout dataLayout)50*89c4ff92SAndroid Build Coastguard Worker void BatchToSpaceNdEndToEnd(const std::vector<BackendId>& backends, armnn::DataLayout dataLayout)
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
53*89c4ff92SAndroid Build Coastguard Worker     using T = ResolveType<ArmnnType>;
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> blockShape {2, 2};
56*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::pair<unsigned int, unsigned int>> crops = {{0, 0}, {0, 0}};
57*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& inputShape  = { 4, 1, 1, 1 };
58*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& outputShape = (dataLayout == DataLayout::NCHW)
59*89c4ff92SAndroid Build Coastguard Worker                                      ? std::initializer_list<unsigned int>({ 1, 1, 2, 2 })
60*89c4ff92SAndroid Build Coastguard Worker                                      : std::initializer_list<unsigned int>({ 1, 2, 2, 1 });
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network
63*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net = CreateBatchToSpaceNdNetwork<ArmnnType>(inputShape, outputShape, blockShape, crops, dataLayout);
64*89c4ff92SAndroid Build Coastguard Worker 
65*89c4ff92SAndroid Build Coastguard Worker     CHECK(net);
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output.
68*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> inputData{ 1, 2, 3, 4 };
69*89c4ff92SAndroid Build Coastguard Worker 
70*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutput{ 1, 2, 3, 4 };
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> inputTensorData = { { 0, inputData } };
73*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> expectedOutputData = { { 0, expectedOutput } };
74*89c4ff92SAndroid Build Coastguard Worker 
75*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
BatchToSpaceNdComplexEndToEnd(const std::vector<BackendId> & backends,armnn::DataLayout dataLayout)79*89c4ff92SAndroid Build Coastguard Worker void BatchToSpaceNdComplexEndToEnd(const std::vector<BackendId>& backends, armnn::DataLayout dataLayout)
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
82*89c4ff92SAndroid Build Coastguard Worker     using T = ResolveType<ArmnnType>;
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> blockShape {2, 2};
85*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::pair<unsigned int, unsigned int>> crops = {{0, 0}, {2, 0}};
86*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& inputShape  = (dataLayout == DataLayout::NCHW)
87*89c4ff92SAndroid Build Coastguard Worker                                      ? std::initializer_list<unsigned int>({ 8, 1, 1, 3 })
88*89c4ff92SAndroid Build Coastguard Worker                                      : std::initializer_list<unsigned int>({ 8, 1, 3, 1 });
89*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& outputShape = (dataLayout == DataLayout::NCHW)
90*89c4ff92SAndroid Build Coastguard Worker                                      ? std::initializer_list<unsigned int>({ 2, 1, 2, 4 })
91*89c4ff92SAndroid Build Coastguard Worker                                      : std::initializer_list<unsigned int>({ 2, 2, 4, 1 });
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network
94*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net = CreateBatchToSpaceNdNetwork<ArmnnType>(inputShape, outputShape, blockShape, crops, dataLayout);
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker     CHECK(net);
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output.
99*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> inputData{
100*89c4ff92SAndroid Build Coastguard Worker                               0, 1, 3, 0,  9, 11,
101*89c4ff92SAndroid Build Coastguard Worker                               0, 2, 4, 0, 10, 12,
102*89c4ff92SAndroid Build Coastguard Worker                               0, 5, 7, 0, 13, 15,
103*89c4ff92SAndroid Build Coastguard Worker                               0, 6, 8, 0, 14, 16
104*89c4ff92SAndroid Build Coastguard Worker                             };
105*89c4ff92SAndroid Build Coastguard Worker 
106*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutput{
107*89c4ff92SAndroid Build Coastguard Worker                                    1,   2,  3,  4,
108*89c4ff92SAndroid Build Coastguard Worker                                    5,   6,  7,  8,
109*89c4ff92SAndroid Build Coastguard Worker                                    9,  10, 11, 12,
110*89c4ff92SAndroid Build Coastguard Worker                                    13, 14, 15, 16
111*89c4ff92SAndroid Build Coastguard Worker                                  };
112*89c4ff92SAndroid Build Coastguard Worker 
113*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> inputTensorData = { { 0, inputData } };
114*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> expectedOutputData = { { 0, expectedOutput } };
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
117*89c4ff92SAndroid Build Coastguard Worker }
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
120