1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include "EndToEndTestImpl.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/QuantizeHelper.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/DataLayoutUtils.hpp>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker #include <map>
16*89c4ff92SAndroid Build Coastguard Worker #include <vector>
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker namespace
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker
// Builds a small test network around a single Convolution2d layer:
//
//   input (binding id 0) ----------------------> convolution2d ---> output (binding id 0)
//   constant "Weights" layer ------------------>      ^
//   constant "Bias" layer (only if biasEnabled) ------+
//
// descriptor  - convolution parameters handed to AddConvolution2dLayer.
// inputInfo   - tensor info passed to Connect for the input -> convolution link.
// weightsInfo - tensor info passed to Connect for the weights -> convolution link.
// biasInfo    - tensor info passed to Connect for the bias -> convolution link
//               (unused when biasEnabled is false).
// outputInfo  - tensor info passed to Connect for the convolution -> output link.
// weights     - constant data backing the "Weights" layer.
// biases      - constant data backing the "Bias" layer (unused when biasEnabled is false).
// biasEnabled - when true, a constant bias layer is created and wired in.
//
// Returns the populated network.
armnn::INetworkPtr CreateConstConvolution2dNetwork(const armnn::Convolution2dDescriptor& descriptor,
                                                   const armnn::TensorInfo& inputInfo,
                                                   const armnn::TensorInfo& weightsInfo,
                                                   const armnn::TensorInfo& biasInfo,
                                                   const armnn::TensorInfo& outputInfo,
                                                   const armnn::ConstTensor& weights,
                                                   const armnn::ConstTensor& biases,
                                                   bool biasEnabled)
{
    using namespace armnn;

    INetworkPtr net = INetwork::Create();

    // Layers are created in a fixed order: input, weights, convolution, output.
    IConnectableLayer* const inputLayer   = net->AddInputLayer(0, "input");
    IConnectableLayer* const weightsConst = net->AddConstantLayer(weights, "Weights");
    IConnectableLayer* const convLayer    = net->AddConvolution2dLayer(descriptor, "convolution2d");
    IConnectableLayer* const outputLayer  = net->AddOutputLayer(0, "output");

    // The trailing two Connect arguments are presumably (source output slot, destination input slot);
    // the convolution receives data on 0, weights on 1 and, optionally, bias on 2.
    Connect(inputLayer, convLayer, inputInfo, 0, 0);
    Connect(weightsConst, convLayer, weightsInfo, 0, 1);

    if (biasEnabled)
    {
        IConnectableLayer* const biasConst = net->AddConstantLayer(biases, "Bias");
        Connect(biasConst, convLayer, biasInfo, 0, 2);
    }

    Connect(convLayer, outputLayer, outputInfo, 0, 0);

    return net;
}
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
Convolution2dEndToEnd(const std::vector<armnn::BackendId> & backends,armnn::DataLayout dataLayout,bool biasEnabled=true)53*89c4ff92SAndroid Build Coastguard Worker void Convolution2dEndToEnd(const std::vector<armnn::BackendId>& backends,
54*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout dataLayout,
55*89c4ff92SAndroid Build Coastguard Worker bool biasEnabled = true)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
58*89c4ff92SAndroid Build Coastguard Worker
59*89c4ff92SAndroid Build Coastguard Worker const float qScale = IsQuantizedType<T>() ? 0.25f : 1.0f;
60*89c4ff92SAndroid Build Coastguard Worker const int32_t qOffset = IsQuantizedType<T>() ? 50 : 0;
61*89c4ff92SAndroid Build Coastguard Worker
62*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo({ 1, 5, 5, 1 }, ArmnnType, qScale, qOffset, true);
63*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo({ 1, 3, 3, 1 }, ArmnnType, qScale, qOffset);
64*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsInfo({ 1, 3, 3, 1 }, ArmnnType, qScale, qOffset, true);
65*89c4ff92SAndroid Build Coastguard Worker TensorInfo biasesInfo({ 1 }, ArmnnType, qScale * qScale, 0, true);
66*89c4ff92SAndroid Build Coastguard Worker
67*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData =
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker 1.0f, 5.0f, 2.0f, 3.0f, 5.0f,
70*89c4ff92SAndroid Build Coastguard Worker 8.0f, 7.0f, 3.0f, 6.0f, 3.0f,
71*89c4ff92SAndroid Build Coastguard Worker 3.0f, 3.0f, 9.0f, 1.0f, 9.0f,
72*89c4ff92SAndroid Build Coastguard Worker 4.0f, 1.0f, 8.0f, 1.0f, 3.0f,
73*89c4ff92SAndroid Build Coastguard Worker 6.0f, 8.0f, 1.0f, 9.0f, 2.0f
74*89c4ff92SAndroid Build Coastguard Worker };
75*89c4ff92SAndroid Build Coastguard Worker
76*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsData =
77*89c4ff92SAndroid Build Coastguard Worker {
78*89c4ff92SAndroid Build Coastguard Worker 4.0f, 5.0f, 6.0f,
79*89c4ff92SAndroid Build Coastguard Worker 0.0f, 0.0f, 0.0f,
80*89c4ff92SAndroid Build Coastguard Worker 3.0f, 2.0f, 1.0f
81*89c4ff92SAndroid Build Coastguard Worker };
82*89c4ff92SAndroid Build Coastguard Worker
83*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasesData = { 1.0f };
84*89c4ff92SAndroid Build Coastguard Worker
85*89c4ff92SAndroid Build Coastguard Worker float bias = biasEnabled ? biasesData[0] : 0.0f;
86*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputData =
87*89c4ff92SAndroid Build Coastguard Worker {
88*89c4ff92SAndroid Build Coastguard Worker 65.0f + bias, 76.0f + bias, 91.0f + bias,
89*89c4ff92SAndroid Build Coastguard Worker 107.0f + bias, 99.0f + bias, 89.0f + bias,
90*89c4ff92SAndroid Build Coastguard Worker 116.0f + bias, 98.0f + bias, 118.0f + bias,
91*89c4ff92SAndroid Build Coastguard Worker };
92*89c4ff92SAndroid Build Coastguard Worker
93*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor descriptor;
94*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft = 0;
95*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadRight = 0;
96*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop = 0;
97*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadBottom = 0;
98*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX = 1;
99*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY = 1;
100*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = biasEnabled;
101*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = dataLayout;
102*89c4ff92SAndroid Build Coastguard Worker
103*89c4ff92SAndroid Build Coastguard Worker if (dataLayout == DataLayout::NCHW)
104*89c4ff92SAndroid Build Coastguard Worker {
105*89c4ff92SAndroid Build Coastguard Worker PermuteTensorNhwcToNchw(inputInfo, inputData);
106*89c4ff92SAndroid Build Coastguard Worker PermuteTensorNhwcToNchw(weightsInfo, weightsData);
107*89c4ff92SAndroid Build Coastguard Worker PermuteTensorNhwcToNchw(outputInfo, expectedOutputData);
108*89c4ff92SAndroid Build Coastguard Worker }
109*89c4ff92SAndroid Build Coastguard Worker
110*89c4ff92SAndroid Build Coastguard Worker // Quantize data
111*89c4ff92SAndroid Build Coastguard Worker std::vector<T> qInputData = armnnUtils::QuantizedVector<T>(inputData, qScale, qOffset);
112*89c4ff92SAndroid Build Coastguard Worker std::vector<T> qWeightsData = armnnUtils::QuantizedVector<T>(weightsData, qScale, qOffset);
113*89c4ff92SAndroid Build Coastguard Worker std::vector<T> qExpectedOutputData = armnnUtils::QuantizedVector<T>(expectedOutputData, qScale, qOffset);
114*89c4ff92SAndroid Build Coastguard Worker std::vector<T> qBiasesData = armnnUtils::QuantizedVector<T>(biasesData, qScale * qScale, 0);
115*89c4ff92SAndroid Build Coastguard Worker
116*89c4ff92SAndroid Build Coastguard Worker ConstTensor weights(weightsInfo, qWeightsData);
117*89c4ff92SAndroid Build Coastguard Worker ConstTensor biases(biasesInfo, qBiasesData);
118*89c4ff92SAndroid Build Coastguard Worker
119*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network = CreateConstConvolution2dNetwork(descriptor,
120*89c4ff92SAndroid Build Coastguard Worker inputInfo,
121*89c4ff92SAndroid Build Coastguard Worker weightsInfo,
122*89c4ff92SAndroid Build Coastguard Worker biasesInfo,
123*89c4ff92SAndroid Build Coastguard Worker outputInfo,
124*89c4ff92SAndroid Build Coastguard Worker weights,
125*89c4ff92SAndroid Build Coastguard Worker biases,
126*89c4ff92SAndroid Build Coastguard Worker biasEnabled);
127*89c4ff92SAndroid Build Coastguard Worker
128*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network),
129*89c4ff92SAndroid Build Coastguard Worker {{ 0, qInputData }},
130*89c4ff92SAndroid Build Coastguard Worker {{ 0, qExpectedOutputData }},
131*89c4ff92SAndroid Build Coastguard Worker backends);
132*89c4ff92SAndroid Build Coastguard Worker }
133*89c4ff92SAndroid Build Coastguard Worker
134*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
135