xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/Convolution2dEndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include "EndToEndTestImpl.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/QuantizeHelper.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/DataLayoutUtils.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker #include <map>
16*89c4ff92SAndroid Build Coastguard Worker #include <vector>
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker namespace
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker 
// Builds the graph  input -> Convolution2d -> output.
// The convolution's weights arrive on input slot 1 from a Constant layer holding `weights`.
// When `biasEnabled` is true, a second Constant layer holding `biases` is wired into slot 2.
// Every connection is given the matching TensorInfo supplied by the caller.
armnn::INetworkPtr CreateConstConvolution2dNetwork(const armnn::Convolution2dDescriptor& descriptor,
                                                   const armnn::TensorInfo& inputInfo,
                                                   const armnn::TensorInfo& weightsInfo,
                                                   const armnn::TensorInfo& biasInfo,
                                                   const armnn::TensorInfo& outputInfo,
                                                   const armnn::ConstTensor& weights,
                                                   const armnn::ConstTensor& biases,
                                                   bool biasEnabled)
{
    using namespace armnn;

    INetworkPtr net = INetwork::Create();

    // Layers are created in a fixed order: input, weights, convolution, output.
    IConnectableLayer* const inputLayer   = net->AddInputLayer(0, "input");
    IConnectableLayer* const weightsConst = net->AddConstantLayer(weights, "Weights");
    IConnectableLayer* const convLayer    = net->AddConvolution2dLayer(descriptor, "convolution2d");
    IConnectableLayer* const outputLayer  = net->AddOutputLayer(0, "output");

    // Slot 0 receives the data tensor and slot 1 the constant weights.
    Connect(inputLayer,   convLayer, inputInfo,   0, 0);
    Connect(weightsConst, convLayer, weightsInfo, 0, 1);

    // The optional bias occupies slot 2 and only exists when requested.
    if (biasEnabled)
    {
        IConnectableLayer* const biasConst = net->AddConstantLayer(biases, "Bias");
        Connect(biasConst, convLayer, biasInfo, 0, 2);
    }

    Connect(convLayer, outputLayer, outputInfo, 0, 0);

    return net;
}
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
Convolution2dEndToEnd(const std::vector<armnn::BackendId> & backends,armnn::DataLayout dataLayout,bool biasEnabled=true)53*89c4ff92SAndroid Build Coastguard Worker void Convolution2dEndToEnd(const std::vector<armnn::BackendId>& backends,
54*89c4ff92SAndroid Build Coastguard Worker                            armnn::DataLayout dataLayout,
55*89c4ff92SAndroid Build Coastguard Worker                            bool biasEnabled = true)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     const float   qScale  = IsQuantizedType<T>() ? 0.25f : 1.0f;
60*89c4ff92SAndroid Build Coastguard Worker     const int32_t qOffset = IsQuantizedType<T>() ? 50    : 0;
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputInfo({ 1, 5, 5, 1 }, ArmnnType, qScale, qOffset, true);
63*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputInfo({ 1, 3, 3, 1 }, ArmnnType, qScale, qOffset);
64*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightsInfo({ 1, 3, 3, 1 }, ArmnnType, qScale, qOffset, true);
65*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasesInfo({ 1 }, ArmnnType, qScale * qScale, 0, true);
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData =
68*89c4ff92SAndroid Build Coastguard Worker     {
69*89c4ff92SAndroid Build Coastguard Worker         1.0f, 5.0f, 2.0f, 3.0f, 5.0f,
70*89c4ff92SAndroid Build Coastguard Worker         8.0f, 7.0f, 3.0f, 6.0f, 3.0f,
71*89c4ff92SAndroid Build Coastguard Worker         3.0f, 3.0f, 9.0f, 1.0f, 9.0f,
72*89c4ff92SAndroid Build Coastguard Worker         4.0f, 1.0f, 8.0f, 1.0f, 3.0f,
73*89c4ff92SAndroid Build Coastguard Worker         6.0f, 8.0f, 1.0f, 9.0f, 2.0f
74*89c4ff92SAndroid Build Coastguard Worker     };
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData =
77*89c4ff92SAndroid Build Coastguard Worker     {
78*89c4ff92SAndroid Build Coastguard Worker         4.0f, 5.0f, 6.0f,
79*89c4ff92SAndroid Build Coastguard Worker         0.0f, 0.0f, 0.0f,
80*89c4ff92SAndroid Build Coastguard Worker         3.0f, 2.0f, 1.0f
81*89c4ff92SAndroid Build Coastguard Worker     };
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasesData = { 1.0f };
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker     float bias = biasEnabled ? biasesData[0] : 0.0f;
86*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputData =
87*89c4ff92SAndroid Build Coastguard Worker     {
88*89c4ff92SAndroid Build Coastguard Worker         65.0f + bias,  76.0f + bias,  91.0f + bias,
89*89c4ff92SAndroid Build Coastguard Worker         107.0f + bias, 99.0f + bias,  89.0f + bias,
90*89c4ff92SAndroid Build Coastguard Worker         116.0f + bias, 98.0f + bias,  118.0f + bias,
91*89c4ff92SAndroid Build Coastguard Worker     };
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor descriptor;
94*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadLeft     = 0;
95*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadRight    = 0;
96*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadTop      = 0;
97*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBottom   = 0;
98*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX     = 1;
99*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY     = 1;
100*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = biasEnabled;
101*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout  = dataLayout;
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker     if (dataLayout == DataLayout::NCHW)
104*89c4ff92SAndroid Build Coastguard Worker     {
105*89c4ff92SAndroid Build Coastguard Worker         PermuteTensorNhwcToNchw(inputInfo, inputData);
106*89c4ff92SAndroid Build Coastguard Worker         PermuteTensorNhwcToNchw(weightsInfo, weightsData);
107*89c4ff92SAndroid Build Coastguard Worker         PermuteTensorNhwcToNchw(outputInfo, expectedOutputData);
108*89c4ff92SAndroid Build Coastguard Worker     }
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker     // Quantize data
111*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qInputData          = armnnUtils::QuantizedVector<T>(inputData, qScale, qOffset);
112*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qWeightsData        = armnnUtils::QuantizedVector<T>(weightsData, qScale, qOffset);
113*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qExpectedOutputData = armnnUtils::QuantizedVector<T>(expectedOutputData, qScale, qOffset);
114*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qBiasesData         = armnnUtils::QuantizedVector<T>(biasesData, qScale * qScale, 0);
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker     ConstTensor weights(weightsInfo, qWeightsData);
117*89c4ff92SAndroid Build Coastguard Worker     ConstTensor biases(biasesInfo, qBiasesData);
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = CreateConstConvolution2dNetwork(descriptor,
120*89c4ff92SAndroid Build Coastguard Worker                                                           inputInfo,
121*89c4ff92SAndroid Build Coastguard Worker                                                           weightsInfo,
122*89c4ff92SAndroid Build Coastguard Worker                                                           biasesInfo,
123*89c4ff92SAndroid Build Coastguard Worker                                                           outputInfo,
124*89c4ff92SAndroid Build Coastguard Worker                                                           weights,
125*89c4ff92SAndroid Build Coastguard Worker                                                           biases,
126*89c4ff92SAndroid Build Coastguard Worker                                                           biasEnabled);
127*89c4ff92SAndroid Build Coastguard Worker 
128*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network),
129*89c4ff92SAndroid Build Coastguard Worker                                                 {{ 0, qInputData }},
130*89c4ff92SAndroid Build Coastguard Worker                                                 {{ 0, qExpectedOutputData }},
131*89c4ff92SAndroid Build Coastguard Worker                                                 backends);
132*89c4ff92SAndroid Build Coastguard Worker }
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
135