xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/DepthToSpaceEndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/QuantizeHelper.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/DataLayoutUtils.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker 
CreateDepthToSpaceNetwork(const armnn::TensorInfo & inputInfo,const armnn::TensorInfo & outputInfo,const armnn::DepthToSpaceDescriptor & descriptor)18*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateDepthToSpaceNetwork(const armnn::TensorInfo& inputInfo,
19*89c4ff92SAndroid Build Coastguard Worker                                              const armnn::TensorInfo& outputInfo,
20*89c4ff92SAndroid Build Coastguard Worker                                              const armnn::DepthToSpaceDescriptor& descriptor)
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network(INetwork::Create());
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input        = network->AddInputLayer(0, "input");
27*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* depthToSpace = network->AddDepthToSpaceLayer(descriptor, "depthToSpace");
28*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output       = network->AddOutputLayer(0, "output");
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker     Connect(input, depthToSpace, inputInfo, 0, 0);
31*89c4ff92SAndroid Build Coastguard Worker     Connect(depthToSpace, output, outputInfo, 0, 0);
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     return network;
34*89c4ff92SAndroid Build Coastguard Worker }
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
DepthToSpaceEndToEndImpl(const std::vector<armnn::BackendId> & backends,const DepthToSpaceDescriptor & descriptor,const armnn::TensorShape & nhwcInputShape,const armnn::TensorShape & nhwcOutputShape,const std::vector<float> & floatInputData,const std::vector<float> & floatExpectedOutputData)37*89c4ff92SAndroid Build Coastguard Worker void DepthToSpaceEndToEndImpl(const std::vector<armnn::BackendId>& backends,
38*89c4ff92SAndroid Build Coastguard Worker                               const DepthToSpaceDescriptor& descriptor,
39*89c4ff92SAndroid Build Coastguard Worker                               const armnn::TensorShape& nhwcInputShape,
40*89c4ff92SAndroid Build Coastguard Worker                               const armnn::TensorShape& nhwcOutputShape,
41*89c4ff92SAndroid Build Coastguard Worker                               const std::vector<float>& floatInputData,
42*89c4ff92SAndroid Build Coastguard Worker                               const std::vector<float>& floatExpectedOutputData)
43*89c4ff92SAndroid Build Coastguard Worker {
44*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputInfo(nhwcInputShape, ArmnnType);
47*89c4ff92SAndroid Build Coastguard Worker     inputInfo.SetConstant(true);
48*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputInfo(nhwcOutputShape, ArmnnType);
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker     constexpr float   qScale  = 0.25f;
51*89c4ff92SAndroid Build Coastguard Worker     constexpr int32_t qOffset = 128;
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker     // Set quantization parameters for quantized types
54*89c4ff92SAndroid Build Coastguard Worker     if (IsQuantizedType<T>())
55*89c4ff92SAndroid Build Coastguard Worker     {
56*89c4ff92SAndroid Build Coastguard Worker         inputInfo.SetQuantizationScale(qScale);
57*89c4ff92SAndroid Build Coastguard Worker         inputInfo.SetQuantizationOffset(qOffset);
58*89c4ff92SAndroid Build Coastguard Worker         outputInfo.SetQuantizationScale(qScale);
59*89c4ff92SAndroid Build Coastguard Worker         outputInfo.SetQuantizationOffset(qOffset);
60*89c4ff92SAndroid Build Coastguard Worker     }
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> inputData          = armnnUtils::QuantizedVector<T>(floatInputData, qScale, qOffset);
63*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutputData = armnnUtils::QuantizedVector<T>(floatExpectedOutputData, qScale, qOffset);
64*89c4ff92SAndroid Build Coastguard Worker 
65*89c4ff92SAndroid Build Coastguard Worker     // Permute tensors from NHWC to NCHW (if needed)
66*89c4ff92SAndroid Build Coastguard Worker     if (descriptor.m_DataLayout == DataLayout::NCHW)
67*89c4ff92SAndroid Build Coastguard Worker     {
68*89c4ff92SAndroid Build Coastguard Worker         PermuteTensorNhwcToNchw(inputInfo, inputData);
69*89c4ff92SAndroid Build Coastguard Worker         PermuteTensorNhwcToNchw(outputInfo, expectedOutputData);
70*89c4ff92SAndroid Build Coastguard Worker     }
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = CreateDepthToSpaceNetwork(inputInfo, outputInfo, descriptor);
73*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network),
74*89c4ff92SAndroid Build Coastguard Worker                                                 { { 0, inputData } },
75*89c4ff92SAndroid Build Coastguard Worker                                                 { { 0, expectedOutputData } },
76*89c4ff92SAndroid Build Coastguard Worker                                                 backends);
77*89c4ff92SAndroid Build Coastguard Worker }
78*89c4ff92SAndroid Build Coastguard Worker 
79*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
DepthToSpaceEndToEnd(const std::vector<armnn::BackendId> & defaultBackends,armnn::DataLayout dataLayout)82*89c4ff92SAndroid Build Coastguard Worker void DepthToSpaceEndToEnd(const std::vector<armnn::BackendId>& defaultBackends,
83*89c4ff92SAndroid Build Coastguard Worker                           armnn::DataLayout dataLayout)
84*89c4ff92SAndroid Build Coastguard Worker {
85*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     TensorShape inputShape  = { 2, 2, 2, 4 };
88*89c4ff92SAndroid Build Coastguard Worker     TensorShape outputShape = { 2, 4, 4, 1 };
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData =
91*89c4ff92SAndroid Build Coastguard Worker     {
92*89c4ff92SAndroid Build Coastguard Worker          1.f,  2.f,  3.f,  4.f,
93*89c4ff92SAndroid Build Coastguard Worker          5.f,  6.f,  7.f,  8.f,
94*89c4ff92SAndroid Build Coastguard Worker          9.f, 10.f, 11.f, 12.f,
95*89c4ff92SAndroid Build Coastguard Worker         13.f, 14.f, 15.f, 16.f,
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker         17.f, 18.f, 19.f, 20.f,
98*89c4ff92SAndroid Build Coastguard Worker         21.f, 22.f, 23.f, 24.f,
99*89c4ff92SAndroid Build Coastguard Worker         25.f, 26.f, 27.f, 28.f,
100*89c4ff92SAndroid Build Coastguard Worker         29.f, 30.f, 31.f, 32.f
101*89c4ff92SAndroid Build Coastguard Worker     };
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputData =
104*89c4ff92SAndroid Build Coastguard Worker     {
105*89c4ff92SAndroid Build Coastguard Worker          1.f,  2.f,  5.f,  6.f,
106*89c4ff92SAndroid Build Coastguard Worker          3.f,  4.f,  7.f,  8.f,
107*89c4ff92SAndroid Build Coastguard Worker          9.f, 10.f, 13.f, 14.f,
108*89c4ff92SAndroid Build Coastguard Worker         11.f, 12.f, 15.f, 16.f,
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker         17.f, 18.f, 21.f, 22.f,
111*89c4ff92SAndroid Build Coastguard Worker         19.f, 20.f, 23.f, 24.f,
112*89c4ff92SAndroid Build Coastguard Worker         25.f, 26.f, 29.f, 30.f,
113*89c4ff92SAndroid Build Coastguard Worker         27.f, 28.f, 31.f, 32.f
114*89c4ff92SAndroid Build Coastguard Worker     };
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker     DepthToSpaceEndToEndImpl<ArmnnType>(defaultBackends,
117*89c4ff92SAndroid Build Coastguard Worker                                         DepthToSpaceDescriptor(2, dataLayout),
118*89c4ff92SAndroid Build Coastguard Worker                                         inputShape,
119*89c4ff92SAndroid Build Coastguard Worker                                         outputShape,
120*89c4ff92SAndroid Build Coastguard Worker                                         inputData,
121*89c4ff92SAndroid Build Coastguard Worker                                         expectedOutputData);
122*89c4ff92SAndroid Build Coastguard Worker }
123