xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/DepthwiseConvolution2dEndToEndTests.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include "EndToEndTestImpl.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/QuantizeHelper.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/DataLayoutUtils.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker #include <map>
16*89c4ff92SAndroid Build Coastguard Worker #include <vector>
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker namespace
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker 
CreateDepthwiseConvolution2dNetwork(const armnn::DepthwiseConvolution2dDescriptor & descriptor,const armnn::TensorInfo & inputInfo,const armnn::TensorInfo & weightsInfo,const armnn::TensorInfo & biasInfo,const armnn::TensorInfo & outputInfo,const armnn::ConstTensor & weights,const armnn::ConstTensor & biases)21*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateDepthwiseConvolution2dNetwork(const armnn::DepthwiseConvolution2dDescriptor& descriptor,
22*89c4ff92SAndroid Build Coastguard Worker                                                        const armnn::TensorInfo& inputInfo,
23*89c4ff92SAndroid Build Coastguard Worker                                                        const armnn::TensorInfo& weightsInfo,
24*89c4ff92SAndroid Build Coastguard Worker                                                        const armnn::TensorInfo& biasInfo,
25*89c4ff92SAndroid Build Coastguard Worker                                                        const armnn::TensorInfo& outputInfo,
26*89c4ff92SAndroid Build Coastguard Worker                                                        const armnn::ConstTensor& weights,
27*89c4ff92SAndroid Build Coastguard Worker                                                        const armnn::ConstTensor& biases)
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network(INetwork::Create());
32*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = network->AddInputLayer(0, "input");
33*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* weightsLayer = network->AddConstantLayer(weights, "Weights");
34*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* biasLayer = network->AddConstantLayer(biases, "Bias");
35*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* convolution2d = network->AddDepthwiseConvolution2dLayer(descriptor, "depthwiseConvolution2d");
36*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = network->AddOutputLayer(0, "output");
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker     Connect(input, convolution2d, inputInfo, 0, 0);
39*89c4ff92SAndroid Build Coastguard Worker     Connect(weightsLayer, convolution2d, weightsInfo, 0, 1);
40*89c4ff92SAndroid Build Coastguard Worker     Connect(biasLayer, convolution2d, biasInfo, 0, 2);
41*89c4ff92SAndroid Build Coastguard Worker     Connect(convolution2d, output, outputInfo, 0, 0);
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker     return network;
44*89c4ff92SAndroid Build Coastguard Worker }
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, armnn::DataType ArmnnBType>
DepthwiseConvolution2dEndToEnd(const std::vector<armnn::BackendId> & backends,armnn::DataLayout dataLayout)49*89c4ff92SAndroid Build Coastguard Worker void DepthwiseConvolution2dEndToEnd(const std::vector<armnn::BackendId>& backends,
50*89c4ff92SAndroid Build Coastguard Worker                                     armnn::DataLayout dataLayout)
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
53*89c4ff92SAndroid Build Coastguard Worker     using T  = ResolveType<ArmnnType>;
54*89c4ff92SAndroid Build Coastguard Worker     using BT = ResolveType<ArmnnBType>;
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     const float   qScale  = IsQuantizedType<T>() ? 0.25f : 1.0f;
57*89c4ff92SAndroid Build Coastguard Worker     const int32_t qOffset = IsQuantizedType<T>() ? 50    : 0;
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     unsigned int depthMultiplier = 2;
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputHeight    = 8;
62*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputWidth     = 16;
63*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputChannels  = 2;
64*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputBatchSize = 1;
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker     unsigned int kernelHeight = 5;
67*89c4ff92SAndroid Build Coastguard Worker     unsigned int kernelWidth  = 3;
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputHeight    = inputHeight - kernelHeight + 1 + 2;
70*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputWidth     = (inputWidth - kernelWidth + 1)/2;
71*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputChannels  = inputChannels * depthMultiplier;
72*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputBatchSize = inputBatchSize;
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputInfo({ inputBatchSize, inputChannels, inputHeight, inputWidth }, ArmnnType, qScale, qOffset, true);
75*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputInfo({ outputBatchSize, outputChannels, outputHeight, outputWidth }, ArmnnType, qScale, qOffset);
76*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightsInfo({1, kernelHeight, kernelWidth, outputChannels}, ArmnnType, qScale, qOffset, true);
77*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasesInfo({outputChannels}, ArmnnBType, qScale * qScale, 0, true);
78*89c4ff92SAndroid Build Coastguard Worker 
79*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData =
80*89c4ff92SAndroid Build Coastguard Worker     {
81*89c4ff92SAndroid Build Coastguard Worker         0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
82*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
83*89c4ff92SAndroid Build Coastguard Worker         0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
84*89c4ff92SAndroid Build Coastguard Worker         0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
85*89c4ff92SAndroid Build Coastguard Worker         0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
86*89c4ff92SAndroid Build Coastguard Worker         0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
87*89c4ff92SAndroid Build Coastguard Worker         0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
88*89c4ff92SAndroid Build Coastguard Worker         0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
89*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
90*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
91*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
92*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
93*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
94*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
95*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
96*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f
97*89c4ff92SAndroid Build Coastguard Worker    };
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData =
100*89c4ff92SAndroid Build Coastguard Worker     {
101*89c4ff92SAndroid Build Coastguard Worker         1.0f,  1.0f, 1.0f,
102*89c4ff92SAndroid Build Coastguard Worker         1.0f, -1.0f, 1.0f,
103*89c4ff92SAndroid Build Coastguard Worker         1.0f,  1.0f, 1.0f,
104*89c4ff92SAndroid Build Coastguard Worker         1.0f,  1.0f, 1.0f,
105*89c4ff92SAndroid Build Coastguard Worker         1.0f,  1.0f, 1.0f,
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker         2.0f,  2.0f, 2.0f,
108*89c4ff92SAndroid Build Coastguard Worker         2.0f,  2.0f, 2.0f,
109*89c4ff92SAndroid Build Coastguard Worker         2.0f,  2.0f, 2.0f,
110*89c4ff92SAndroid Build Coastguard Worker         2.0f,  2.0f, 2.0f,
111*89c4ff92SAndroid Build Coastguard Worker         2.0f,  2.0f, 2.0f,
112*89c4ff92SAndroid Build Coastguard Worker 
113*89c4ff92SAndroid Build Coastguard Worker         0.0f,  0.0f, 0.0f,
114*89c4ff92SAndroid Build Coastguard Worker         0.0f, -1.0f, 0.0f,
115*89c4ff92SAndroid Build Coastguard Worker         0.0f,  0.0f, 0.0f,
116*89c4ff92SAndroid Build Coastguard Worker         0.0f,  0.0f, 0.0f,
117*89c4ff92SAndroid Build Coastguard Worker         0.0f,  0.0f, 0.0f,
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker         0.0f,  0.0f, 0.0f,
120*89c4ff92SAndroid Build Coastguard Worker         0.0f,  0.0f, 0.0f,
121*89c4ff92SAndroid Build Coastguard Worker         0.0f,  1.0f, 0.0f,
122*89c4ff92SAndroid Build Coastguard Worker         0.0f,  0.0f, 0.0f,
123*89c4ff92SAndroid Build Coastguard Worker         0.0f,  0.0f, 0.0f
124*89c4ff92SAndroid Build Coastguard Worker     };
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasesData = { 0.0f, 2.0f, 1.0f, -1.0f };
127*89c4ff92SAndroid Build Coastguard Worker 
128*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputData =
129*89c4ff92SAndroid Build Coastguard Worker     {
130*89c4ff92SAndroid Build Coastguard Worker         3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f,
131*89c4ff92SAndroid Build Coastguard Worker         5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.5f, 5.5f, 5.5f, 5.5f, 5.5f, 5.5f, 5.5f,
132*89c4ff92SAndroid Build Coastguard Worker         5.5f, 5.5f, 5.5f, 5.5f, 5.5f, 5.5f, 5.5f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f,
133*89c4ff92SAndroid Build Coastguard Worker         2.5f, 2.5f, 2.5f, 2.5f, 2.5f, 2.5f, 2.5f, 3.5f, 3.5f, 3.5f, 3.5f, 3.5f, 3.5f, 3.5f,
134*89c4ff92SAndroid Build Coastguard Worker         4.5f, 4.5f, 4.5f, 4.5f, 4.5f, 4.5f, 4.5f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f,
135*89c4ff92SAndroid Build Coastguard Worker         6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f,
136*89c4ff92SAndroid Build Coastguard Worker         1.0f, 3.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.0f, 4.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
137*89c4ff92SAndroid Build Coastguard Worker         2.0f, 4.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.0f, 4.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
138*89c4ff92SAndroid Build Coastguard Worker         2.0f, 4.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.0f, 4.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
139*89c4ff92SAndroid Build Coastguard Worker         2.0f, 4.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 3.0f, 5.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
140*89c4ff92SAndroid Build Coastguard Worker         3.0f, 5.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 3.0f, 5.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
141*89c4ff92SAndroid Build Coastguard Worker         3.0f, 5.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 3.0f, 5.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f
142*89c4ff92SAndroid Build Coastguard Worker    };
143*89c4ff92SAndroid Build Coastguard Worker 
144*89c4ff92SAndroid Build Coastguard Worker     DepthwiseConvolution2dDescriptor descriptor;
145*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadLeft     = 0;
146*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadRight    = 0;
147*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadTop      = 1;
148*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBottom   = 0;
149*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX     = 2;
150*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY     = 1;
151*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
152*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout  = dataLayout;
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker     // Permute input and output if NCDHW.
155*89c4ff92SAndroid Build Coastguard Worker     if (dataLayout == DataLayout::NCHW)
156*89c4ff92SAndroid Build Coastguard Worker     {
157*89c4ff92SAndroid Build Coastguard Worker         PermuteTensorNhwcToNchw(inputInfo, inputData);
158*89c4ff92SAndroid Build Coastguard Worker         PermuteTensorNhwcToNchw(outputInfo, expectedOutputData);
159*89c4ff92SAndroid Build Coastguard Worker     }
160*89c4ff92SAndroid Build Coastguard Worker 
161*89c4ff92SAndroid Build Coastguard Worker     // Quantize data
162*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qInputData          = armnnUtils::QuantizedVector<T>(inputData, qScale, qOffset);
163*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qWeightsData        = armnnUtils::QuantizedVector<T>(weightsData, qScale, qOffset);
164*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qExpectedOutputData = armnnUtils::QuantizedVector<T>(expectedOutputData, qScale, qOffset);
165*89c4ff92SAndroid Build Coastguard Worker 
166*89c4ff92SAndroid Build Coastguard Worker     std::vector<BT> qBiasesData = armnnUtils::QuantizedVector<BT>(biasesData, qScale * qScale, 0);
167*89c4ff92SAndroid Build Coastguard Worker 
168*89c4ff92SAndroid Build Coastguard Worker     ConstTensor weights(weightsInfo, qWeightsData);
169*89c4ff92SAndroid Build Coastguard Worker     ConstTensor biases(biasesInfo, qBiasesData);
170*89c4ff92SAndroid Build Coastguard Worker 
171*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = CreateDepthwiseConvolution2dNetwork(descriptor,
172*89c4ff92SAndroid Build Coastguard Worker                                                               inputInfo,
173*89c4ff92SAndroid Build Coastguard Worker                                                               weightsInfo,
174*89c4ff92SAndroid Build Coastguard Worker                                                               biasesInfo,
175*89c4ff92SAndroid Build Coastguard Worker                                                               outputInfo,
176*89c4ff92SAndroid Build Coastguard Worker                                                               weights,
177*89c4ff92SAndroid Build Coastguard Worker                                                               biases);
178*89c4ff92SAndroid Build Coastguard Worker 
179*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network),
180*89c4ff92SAndroid Build Coastguard Worker                                                 { { 0, qInputData } },
181*89c4ff92SAndroid Build Coastguard Worker                                                 { { 0, qExpectedOutputData } },
182*89c4ff92SAndroid Build Coastguard Worker                                                 backends);
183*89c4ff92SAndroid Build Coastguard Worker }
184