xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/LayerReleaseConstantDataTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <Graph.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadData.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
16*89c4ff92SAndroid Build Coastguard Worker using namespace std;
17*89c4ff92SAndroid Build Coastguard Worker 
/////////////////////////////////////////////////////////////////////////////////////////////
// The following tests exercise the ReleaseConstantData() method of Layer.
// Each one builds a minimal graph containing the layer under test and checks
// its weights and biases before and after the method is called.
/////////////////////////////////////////////////////////////////////////////////////////////
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("LayerReleaseConstantDataTest")
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReleaseBatchNormalizationLayerConstantDataTest")
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker     // create the layer we're testing
31*89c4ff92SAndroid Build Coastguard Worker     BatchNormalizationDescriptor layerDesc;
32*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_Eps = 0.05f;
33*89c4ff92SAndroid Build Coastguard Worker     BatchNormalizationLayer* const layer = graph.AddLayer<BatchNormalizationLayer>(layerDesc, "layer");
34*89c4ff92SAndroid Build Coastguard Worker 
35*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo weightInfo({3}, armnn::DataType::Float32);
36*89c4ff92SAndroid Build Coastguard Worker     layer->m_Mean     = std::make_unique<ScopedTensorHandle>(weightInfo);
37*89c4ff92SAndroid Build Coastguard Worker     layer->m_Variance = std::make_unique<ScopedTensorHandle>(weightInfo);
38*89c4ff92SAndroid Build Coastguard Worker     layer->m_Beta     = std::make_unique<ScopedTensorHandle>(weightInfo);
39*89c4ff92SAndroid Build Coastguard Worker     layer->m_Gamma    = std::make_unique<ScopedTensorHandle>(weightInfo);
40*89c4ff92SAndroid Build Coastguard Worker     layer->m_Mean->Allocate();
41*89c4ff92SAndroid Build Coastguard Worker     layer->m_Variance->Allocate();
42*89c4ff92SAndroid Build Coastguard Worker     layer->m_Beta->Allocate();
43*89c4ff92SAndroid Build Coastguard Worker     layer->m_Gamma->Allocate();
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker     // create extra layers
46*89c4ff92SAndroid Build Coastguard Worker     Layer* const input = graph.AddLayer<InputLayer>(0, "input");
47*89c4ff92SAndroid Build Coastguard Worker     Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker     // connect up
50*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo tensorInfo({2, 3, 1, 1}, armnn::DataType::Float32);
51*89c4ff92SAndroid Build Coastguard Worker     Connect(input, layer, tensorInfo);
52*89c4ff92SAndroid Build Coastguard Worker     Connect(layer, output, tensorInfo);
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     // check the constants that they are not NULL
55*89c4ff92SAndroid Build Coastguard Worker     CHECK(layer->m_Mean != nullptr);
56*89c4ff92SAndroid Build Coastguard Worker     CHECK(layer->m_Variance != nullptr);
57*89c4ff92SAndroid Build Coastguard Worker     CHECK(layer->m_Beta != nullptr);
58*89c4ff92SAndroid Build Coastguard Worker     CHECK(layer->m_Gamma != nullptr);
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker     // free up the constants..
61*89c4ff92SAndroid Build Coastguard Worker     layer->ReleaseConstantData();
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker     // check the constants that they are NULL now
64*89c4ff92SAndroid Build Coastguard Worker     CHECK(layer->m_Mean == nullptr);
65*89c4ff92SAndroid Build Coastguard Worker     CHECK(layer->m_Variance == nullptr);
66*89c4ff92SAndroid Build Coastguard Worker     CHECK(layer->m_Beta == nullptr);
67*89c4ff92SAndroid Build Coastguard Worker     CHECK(layer->m_Gamma == nullptr);
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker  }
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReleaseConvolution2dLayerConstantDataTest")
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
74*89c4ff92SAndroid Build Coastguard Worker 
75*89c4ff92SAndroid Build Coastguard Worker     // create the layer we're testing
76*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor layerDesc;
77*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_PadLeft = 3;
78*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_PadRight = 3;
79*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_PadTop = 1;
80*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_PadBottom = 1;
81*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_StrideX = 2;
82*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_StrideY = 4;
83*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_BiasEnabled = true;
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker     auto* const convolutionLayer = graph.AddLayer<Convolution2dLayer>(layerDesc, "convolution");
86*89c4ff92SAndroid Build Coastguard Worker     auto* const weightsLayer = graph.AddLayer<ConstantLayer>("weights");
87*89c4ff92SAndroid Build Coastguard Worker     auto* const biasLayer = graph.AddLayer<ConstantLayer>("bias");
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightsInfo = TensorInfo({ 2, 3, 5, 3 }, armnn::DataType::Float32, 1.0, 0.0, true);
90*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasInfo = TensorInfo({ 2 }, GetBiasDataType(armnn::DataType::Float32), 1.0, 0.0, true);
91*89c4ff92SAndroid Build Coastguard Worker 
92*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(weightsInfo);
93*89c4ff92SAndroid Build Coastguard Worker     biasLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(biasInfo);
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
96*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo);
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker     // create extra layers
99*89c4ff92SAndroid Build Coastguard Worker     Layer* const input = graph.AddLayer<InputLayer>(0, "input");
100*89c4ff92SAndroid Build Coastguard Worker     Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker     // connect up
103*89c4ff92SAndroid Build Coastguard Worker     Connect(input, convolutionLayer, TensorInfo({ 2, 3, 8, 16 }, armnn::DataType::Float32));
104*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot().Connect(convolutionLayer->GetInputSlot(1));
105*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot().Connect(convolutionLayer->GetInputSlot(2));
106*89c4ff92SAndroid Build Coastguard Worker     Connect(convolutionLayer, output, TensorInfo({ 2, 2, 2, 10 }, armnn::DataType::Float32));
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker     // check the constants that they are not NULL
109*89c4ff92SAndroid Build Coastguard Worker     CHECK(weightsLayer->m_LayerOutput != nullptr);
110*89c4ff92SAndroid Build Coastguard Worker     CHECK(biasLayer->m_LayerOutput != nullptr);
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker     // free up the constants.
113*89c4ff92SAndroid Build Coastguard Worker     convolutionLayer->ReleaseConstantData();
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker     // check the constants that they are still not NULL
116*89c4ff92SAndroid Build Coastguard Worker     CHECK(weightsLayer->m_LayerOutput != nullptr);
117*89c4ff92SAndroid Build Coastguard Worker     CHECK(biasLayer->m_LayerOutput != nullptr);
118*89c4ff92SAndroid Build Coastguard Worker }
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReleaseDepthwiseConvolution2dLayerConstantDataTest")
121*89c4ff92SAndroid Build Coastguard Worker {
122*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker     // create the layer we're testing
125*89c4ff92SAndroid Build Coastguard Worker     DepthwiseConvolution2dDescriptor layerDesc;
126*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_PadLeft         = 3;
127*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_PadRight        = 3;
128*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_PadTop          = 1;
129*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_PadBottom       = 1;
130*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_StrideX         = 2;
131*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_StrideY         = 4;
132*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_BiasEnabled     = true;
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker     auto* const convolutionLayer = graph.AddLayer<DepthwiseConvolution2dLayer>(layerDesc, "convolution");
135*89c4ff92SAndroid Build Coastguard Worker     auto* const weightsLayer = graph.AddLayer<ConstantLayer>("weights");
136*89c4ff92SAndroid Build Coastguard Worker     auto* const biasLayer = graph.AddLayer<ConstantLayer>("bias");
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightsInfo = TensorInfo({ 3, 3, 5, 3 }, armnn::DataType::Float32, 1.0, 0.0, true);
139*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasInfo = TensorInfo({ 9 }, GetBiasDataType(armnn::DataType::Float32), 1.0, 0.0, true);
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(weightsInfo);
142*89c4ff92SAndroid Build Coastguard Worker     biasLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(biasInfo);
143*89c4ff92SAndroid Build Coastguard Worker 
144*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
145*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo);
146*89c4ff92SAndroid Build Coastguard Worker 
147*89c4ff92SAndroid Build Coastguard Worker     // create extra layers
148*89c4ff92SAndroid Build Coastguard Worker     Layer* const input = graph.AddLayer<InputLayer>(0, "input");
149*89c4ff92SAndroid Build Coastguard Worker     Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker     // connect up
152*89c4ff92SAndroid Build Coastguard Worker     Connect(input, convolutionLayer, TensorInfo({2, 3, 8, 16}, armnn::DataType::Float32));
153*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot().Connect(convolutionLayer->GetInputSlot(1));
154*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot().Connect(convolutionLayer->GetInputSlot(2));
155*89c4ff92SAndroid Build Coastguard Worker     Connect(convolutionLayer, output, TensorInfo({2, 9, 2, 10}, armnn::DataType::Float32));
156*89c4ff92SAndroid Build Coastguard Worker 
157*89c4ff92SAndroid Build Coastguard Worker     // check the constants that they are not NULL
158*89c4ff92SAndroid Build Coastguard Worker     CHECK(weightsLayer->m_LayerOutput != nullptr);
159*89c4ff92SAndroid Build Coastguard Worker     CHECK(biasLayer->m_LayerOutput != nullptr);
160*89c4ff92SAndroid Build Coastguard Worker 
161*89c4ff92SAndroid Build Coastguard Worker     // free up the constants.
162*89c4ff92SAndroid Build Coastguard Worker     convolutionLayer->ReleaseConstantData();
163*89c4ff92SAndroid Build Coastguard Worker 
164*89c4ff92SAndroid Build Coastguard Worker     // check the constants that they are still not NULL
165*89c4ff92SAndroid Build Coastguard Worker     CHECK(weightsLayer->m_LayerOutput != nullptr);
166*89c4ff92SAndroid Build Coastguard Worker     CHECK(biasLayer->m_LayerOutput != nullptr);
167*89c4ff92SAndroid Build Coastguard Worker }
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReleaseFullyConnectedLayerConstantDataTest")
170*89c4ff92SAndroid Build Coastguard Worker {
171*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
172*89c4ff92SAndroid Build Coastguard Worker 
173*89c4ff92SAndroid Build Coastguard Worker     // create the layer we're testing
174*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedDescriptor layerDesc;
175*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_BiasEnabled = true;
176*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_TransposeWeightMatrix = true;
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker     auto* const fullyConnectedLayer = graph.AddLayer<FullyConnectedLayer>(layerDesc, "layer");
179*89c4ff92SAndroid Build Coastguard Worker     auto* const weightsLayer = graph.AddLayer<ConstantLayer>("weights");
180*89c4ff92SAndroid Build Coastguard Worker     auto* const biasLayer = graph.AddLayer<ConstantLayer>("bias");
181*89c4ff92SAndroid Build Coastguard Worker 
182*89c4ff92SAndroid Build Coastguard Worker     float inputsQScale = 1.0f;
183*89c4ff92SAndroid Build Coastguard Worker     float outputQScale = 2.0f;
184*89c4ff92SAndroid Build Coastguard Worker 
185*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightsInfo = TensorInfo({ 7, 20 }, DataType::QAsymmU8, inputsQScale, 0.0, true);
186*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasInfo = TensorInfo({ 7 }, GetBiasDataType(DataType::QAsymmU8), inputsQScale, 0.0, true);
187*89c4ff92SAndroid Build Coastguard Worker 
188*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(weightsInfo);
189*89c4ff92SAndroid Build Coastguard Worker     biasLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(biasInfo);
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
192*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo);
193*89c4ff92SAndroid Build Coastguard Worker 
194*89c4ff92SAndroid Build Coastguard Worker     // create extra layers
195*89c4ff92SAndroid Build Coastguard Worker     Layer* const input = graph.AddLayer<InputLayer>(0, "input");
196*89c4ff92SAndroid Build Coastguard Worker     Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
197*89c4ff92SAndroid Build Coastguard Worker 
198*89c4ff92SAndroid Build Coastguard Worker     // connect up
199*89c4ff92SAndroid Build Coastguard Worker     Connect(input, fullyConnectedLayer, TensorInfo({ 3, 1, 4, 5 }, DataType::QAsymmU8, inputsQScale));
200*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot().Connect(fullyConnectedLayer->GetInputSlot(1));
201*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot().Connect(fullyConnectedLayer->GetInputSlot(2));
202*89c4ff92SAndroid Build Coastguard Worker     Connect(fullyConnectedLayer, output, TensorInfo({ 3, 7 }, DataType::QAsymmU8, outputQScale));
203*89c4ff92SAndroid Build Coastguard Worker 
204*89c4ff92SAndroid Build Coastguard Worker     // check the constants that they are not NULL
205*89c4ff92SAndroid Build Coastguard Worker     CHECK(weightsLayer->m_LayerOutput != nullptr);
206*89c4ff92SAndroid Build Coastguard Worker     CHECK(biasLayer->m_LayerOutput != nullptr);
207*89c4ff92SAndroid Build Coastguard Worker 
208*89c4ff92SAndroid Build Coastguard Worker     // free up the constants.
209*89c4ff92SAndroid Build Coastguard Worker     fullyConnectedLayer->ReleaseConstantData();
210*89c4ff92SAndroid Build Coastguard Worker 
211*89c4ff92SAndroid Build Coastguard Worker     // check the constants that they are still not NULL
212*89c4ff92SAndroid Build Coastguard Worker     CHECK(weightsLayer->m_LayerOutput != nullptr);
213*89c4ff92SAndroid Build Coastguard Worker     CHECK(biasLayer->m_LayerOutput != nullptr);
214*89c4ff92SAndroid Build Coastguard Worker }
215*89c4ff92SAndroid Build Coastguard Worker 
216*89c4ff92SAndroid Build Coastguard Worker }
217*89c4ff92SAndroid Build Coastguard Worker 
218