xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/ActivationEndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include "EndToEndTestImpl.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/TypesUtils.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker namespace
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker /** Defines the acceptable tolerance of ActivationFunction-DataType combinations.
20*89c4ff92SAndroid Build Coastguard Worker  *
21*89c4ff92SAndroid Build Coastguard Worker  * @param activationFunction The activation function used
22*89c4ff92SAndroid Build Coastguard Worker  * @param dataType  Data type used
23*89c4ff92SAndroid Build Coastguard Worker  *
24*89c4ff92SAndroid Build Coastguard Worker  * @return Tolerance depending on the activation function and data type
25*89c4ff92SAndroid Build Coastguard Worker  */
GetActivationTolerance(const armnn::ActivationFunction & activationFunction,DataType dataType)26*89c4ff92SAndroid Build Coastguard Worker float GetActivationTolerance(const armnn::ActivationFunction& activationFunction, DataType dataType)
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker     constexpr float defaultTolerance = 1e-6f;
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker     switch (activationFunction)
31*89c4ff92SAndroid Build Coastguard Worker     {
32*89c4ff92SAndroid Build Coastguard Worker         // The following values are taken from ArmComputeLibrary/tests/validation/CL/ActivationLayer.cpp
33*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::Elu:
34*89c4ff92SAndroid Build Coastguard Worker             return (dataType == DataType::Float16 ? 0.01f : 0.00001f);
35*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::HardSwish:
36*89c4ff92SAndroid Build Coastguard Worker             return (dataType == DataType::Float16 ? 0.01f : defaultTolerance);
37*89c4ff92SAndroid Build Coastguard Worker         default:
38*89c4ff92SAndroid Build Coastguard Worker             return defaultTolerance;
39*89c4ff92SAndroid Build Coastguard Worker     }
40*89c4ff92SAndroid Build Coastguard Worker }
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker /** Creates a network with one layer of the activation function specified in the activation descriptor.
43*89c4ff92SAndroid Build Coastguard Worker  *
44*89c4ff92SAndroid Build Coastguard Worker  * @param inputInfo  Tensor info of inputs
45*89c4ff92SAndroid Build Coastguard Worker  * @param outputInfo  Tensor info of outputs
46*89c4ff92SAndroid Build Coastguard Worker  * @param descriptor  Activation descriptor
47*89c4ff92SAndroid Build Coastguard Worker  *
48*89c4ff92SAndroid Build Coastguard Worker  * @return INetworkPtr  A pointer to the created network
49*89c4ff92SAndroid Build Coastguard Worker  */
CreateActivationNetwork(const armnn::TensorInfo & inputInfo,const armnn::TensorInfo & outputInfo,const armnn::ActivationDescriptor & descriptor)50*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateActivationNetwork(const armnn::TensorInfo& inputInfo,
51*89c4ff92SAndroid Build Coastguard Worker                                            const armnn::TensorInfo& outputInfo,
52*89c4ff92SAndroid Build Coastguard Worker                                            const armnn::ActivationDescriptor& descriptor)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     char const* ActivationName = GetActivationFunctionAsCString(descriptor.m_Function);
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(0, "input");
61*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* prelu = net->AddActivationLayer(descriptor, ActivationName);
62*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker     Connect(input, prelu, inputInfo, 0, 0);
65*89c4ff92SAndroid Build Coastguard Worker     Connect(prelu, output, outputInfo, 0, 0);
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker     return net;
68*89c4ff92SAndroid Build Coastguard Worker }
69*89c4ff92SAndroid Build Coastguard Worker 
70*89c4ff92SAndroid Build Coastguard Worker /** Specifies the implementation of end to end tests for activation functions.
71*89c4ff92SAndroid Build Coastguard Worker  *
72*89c4ff92SAndroid Build Coastguard Worker  *  - Converts input data and expected-output data to the data type that is desired for the test (ArmnnType)
73*89c4ff92SAndroid Build Coastguard Worker  *  - Creates a network with one layer of the activation function specified in the activation descriptor.
74*89c4ff92SAndroid Build Coastguard Worker  *  - Executes the network on specified backends and compares results to expected output values
75*89c4ff92SAndroid Build Coastguard Worker  *
76*89c4ff92SAndroid Build Coastguard Worker  * @tparam ArmnnType  The armnn data type for the input and expected-output data
77*89c4ff92SAndroid Build Coastguard Worker  * @param backends  Backends to run test on
78*89c4ff92SAndroid Build Coastguard Worker  * @param floatInputData  Input data given as vector of float
79*89c4ff92SAndroid Build Coastguard Worker  * @param floatExpectedOutputData  Expected output data given as vector of float
80*89c4ff92SAndroid Build Coastguard Worker  * @param inputInfo  Tensor info of inputs
81*89c4ff92SAndroid Build Coastguard Worker  * @param outputInfo  Tensor info of outputs
82*89c4ff92SAndroid Build Coastguard Worker  * @param descriptor  Activation descriptor
83*89c4ff92SAndroid Build Coastguard Worker  */
84*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
ActivationEndToEndImpl(const std::vector<armnn::BackendId> & backends,const std::vector<float> & floatInputData,const std::vector<float> & floatExpectedOutputData,const armnn::TensorInfo & inputInfo,const armnn::TensorInfo & outputInfo,const armnn::ActivationDescriptor & descriptor)85*89c4ff92SAndroid Build Coastguard Worker void ActivationEndToEndImpl(const std::vector<armnn::BackendId>& backends,
86*89c4ff92SAndroid Build Coastguard Worker                      const std::vector<float>& floatInputData,
87*89c4ff92SAndroid Build Coastguard Worker                      const std::vector<float>& floatExpectedOutputData,
88*89c4ff92SAndroid Build Coastguard Worker                      const armnn::TensorInfo&  inputInfo,
89*89c4ff92SAndroid Build Coastguard Worker                      const armnn::TensorInfo&  outputInfo,
90*89c4ff92SAndroid Build Coastguard Worker                      const armnn::ActivationDescriptor& descriptor)
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker     // Selectively quantizes/transforms float values to the needed data type
95*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> inputData          = armnnUtils::QuantizedVector<T>( floatInputData,
96*89c4ff92SAndroid Build Coastguard Worker                                                                         inputInfo.GetQuantizationScale(),
97*89c4ff92SAndroid Build Coastguard Worker                                                                         inputInfo.GetQuantizationOffset());
98*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> expectedOutputData = armnnUtils::QuantizedVector<T>( floatExpectedOutputData,
99*89c4ff92SAndroid Build Coastguard Worker                                                                         outputInfo.GetQuantizationScale(),
100*89c4ff92SAndroid Build Coastguard Worker                                                                         outputInfo.GetQuantizationOffset());
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net = CreateActivationNetwork(inputInfo, outputInfo, descriptor);
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> inputTensorData          = { { 0, inputData } };
105*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> expectedOutputTensorData = { { 0, expectedOutputData } };
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker     float tolerance = GetActivationTolerance(descriptor.m_Function, ArmnnType);
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net),
110*89c4ff92SAndroid Build Coastguard Worker                                                 inputTensorData,
111*89c4ff92SAndroid Build Coastguard Worker                                                 expectedOutputTensorData,
112*89c4ff92SAndroid Build Coastguard Worker                                                 backends,
113*89c4ff92SAndroid Build Coastguard Worker                                                 tolerance);
114*89c4ff92SAndroid Build Coastguard Worker }
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker /** Executes an end to end test for Elu activation with specific input and expected-output data
117*89c4ff92SAndroid Build Coastguard Worker  *
118*89c4ff92SAndroid Build Coastguard Worker  * @tparam ArmnnType  The armnn data type for the input and expected-output data
119*89c4ff92SAndroid Build Coastguard Worker  * @param backends  The backends on which to run the test
120*89c4ff92SAndroid Build Coastguard Worker  */
121*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
EluEndToEndTest(const std::vector<BackendId> & backends)122*89c4ff92SAndroid Build Coastguard Worker void EluEndToEndTest(const std::vector<BackendId>& backends)
123*89c4ff92SAndroid Build Coastguard Worker {
124*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> floatInputData{ -2.0f, -1.0f, -0.0f, 0.0f,
125*89c4ff92SAndroid Build Coastguard Worker                                         1.0f,  2.0f,  3.0f, 4.0f };
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> floatExpectedOutputData{ -0.86466471676f,  -0.63212055882f,  -0.0f, 0.0f,
128*89c4ff92SAndroid Build Coastguard Worker                                                  1.0f          ,   2.0f          ,   3.0f, 4.0f };
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker     float qScale = 1.0f;
131*89c4ff92SAndroid Build Coastguard Worker     int32_t qOffset = 0;
132*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputInfo({ 2, 2, 2, 1 }, ArmnnType, qScale, qOffset, true);
133*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo({ 2, 2, 2, 1 }, ArmnnType, qScale, qOffset);
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker     armnn::ActivationDescriptor descriptor(ActivationFunction::Elu, 1.0);
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker     ActivationEndToEndImpl<ArmnnType>(backends,
138*89c4ff92SAndroid Build Coastguard Worker                                       floatInputData,
139*89c4ff92SAndroid Build Coastguard Worker                                       floatExpectedOutputData,
140*89c4ff92SAndroid Build Coastguard Worker                                       inputInfo,
141*89c4ff92SAndroid Build Coastguard Worker                                       outputInfo,
142*89c4ff92SAndroid Build Coastguard Worker                                       descriptor);
143*89c4ff92SAndroid Build Coastguard Worker }
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker /** Executes an end to end test for HardSwish activation with specific input and expected-output data
146*89c4ff92SAndroid Build Coastguard Worker  *
147*89c4ff92SAndroid Build Coastguard Worker  * @tparam ArmnnType  The armnn data type for the input and expected-output data
148*89c4ff92SAndroid Build Coastguard Worker  * @param backends  The backends on which to run the test
149*89c4ff92SAndroid Build Coastguard Worker  */
150*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
HardSwishEndToEndTest(const std::vector<BackendId> & backends)151*89c4ff92SAndroid Build Coastguard Worker void HardSwishEndToEndTest(const std::vector<BackendId>& backends)
152*89c4ff92SAndroid Build Coastguard Worker {
153*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> floatInputData{ -2.0f, -1.0f, -0.5f, 0.0f,
154*89c4ff92SAndroid Build Coastguard Worker                                        1.0f,  2.0f,  3.0f, 4.0f };
155*89c4ff92SAndroid Build Coastguard Worker 
156*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> floatExpectedOutputData{ -0.33333333333f,  -0.33333333333f, -0.208333f, 0.0f,
157*89c4ff92SAndroid Build Coastguard Worker                                                  0.66666666667f,   1.66666666667f,  3.0f     , 4.0f };
158*89c4ff92SAndroid Build Coastguard Worker 
159*89c4ff92SAndroid Build Coastguard Worker     float qScale = 1.0f;
160*89c4ff92SAndroid Build Coastguard Worker     int32_t qOffset = 0;
161*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputInfo({ 2, 2, 2, 1 }, ArmnnType, qScale, qOffset, true);
162*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo({ 2, 2, 2, 1 }, ArmnnType, qScale, qOffset);
163*89c4ff92SAndroid Build Coastguard Worker 
164*89c4ff92SAndroid Build Coastguard Worker     armnn::ActivationDescriptor descriptor(ActivationFunction::HardSwish, 1.0);
165*89c4ff92SAndroid Build Coastguard Worker 
166*89c4ff92SAndroid Build Coastguard Worker     ActivationEndToEndImpl<ArmnnType>(backends,
167*89c4ff92SAndroid Build Coastguard Worker                                       floatInputData,
168*89c4ff92SAndroid Build Coastguard Worker                                       floatExpectedOutputData,
169*89c4ff92SAndroid Build Coastguard Worker                                       inputInfo,
170*89c4ff92SAndroid Build Coastguard Worker                                       outputInfo,
171*89c4ff92SAndroid Build Coastguard Worker                                       descriptor);
172*89c4ff92SAndroid Build Coastguard Worker }
173*89c4ff92SAndroid Build Coastguard Worker 
174*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace