xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/DequantizeEndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker template<typename T>
CreateDequantizeNetwork(const armnn::TensorInfo & inputInfo,const armnn::TensorInfo & outputInfo)19*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateDequantizeNetwork(const armnn::TensorInfo& inputInfo,
20*89c4ff92SAndroid Build Coastguard Worker                                            const armnn::TensorInfo& outputInfo)
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr net(armnn::INetwork::Create());
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* inputLayer = net->AddInputLayer(0);
25*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* dequantizeLayer = net->AddDequantizeLayer("Dequantize");
26*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
27*89c4ff92SAndroid Build Coastguard Worker     Connect(inputLayer, dequantizeLayer, inputInfo, 0, 0);
28*89c4ff92SAndroid Build Coastguard Worker     Connect(dequantizeLayer, outputLayer, outputInfo, 0, 0);
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker     return net;
31*89c4ff92SAndroid Build Coastguard Worker }
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
DequantizeEndToEndLayerTestImpl(const std::vector<BackendId> & backends,const armnn::TensorShape & tensorShape,const std::vector<T> & input,const std::vector<float> & expectedOutput,float scale,int32_t offset)34*89c4ff92SAndroid Build Coastguard Worker void DequantizeEndToEndLayerTestImpl(const std::vector<BackendId>& backends,
35*89c4ff92SAndroid Build Coastguard Worker                                      const armnn::TensorShape& tensorShape,
36*89c4ff92SAndroid Build Coastguard Worker                                      const std::vector<T>& input,
37*89c4ff92SAndroid Build Coastguard Worker                                      const std::vector<float>& expectedOutput,
38*89c4ff92SAndroid Build Coastguard Worker                                      float scale,
39*89c4ff92SAndroid Build Coastguard Worker                                      int32_t offset)
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputInfo(tensorShape, ArmnnType);
42*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo(tensorShape, armnn::DataType::Float32);
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     inputInfo.SetQuantizationScale(scale);
45*89c4ff92SAndroid Build Coastguard Worker     inputInfo.SetQuantizationOffset(offset);
46*89c4ff92SAndroid Build Coastguard Worker     inputInfo.SetConstant(true);
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network
49*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr net = CreateDequantizeNetwork<T>(inputInfo, outputInfo);
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker     CHECK(net);
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<T>> inputTensorData = { { 0, input } };
54*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<float>> expectedOutputData = { { 0, expectedOutput } };
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, armnn::DataType::Float32>(
57*89c4ff92SAndroid Build Coastguard Worker             move(net), inputTensorData, expectedOutputData, backends);
58*89c4ff92SAndroid Build Coastguard Worker }
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
DequantizeEndToEndSimple(const std::vector<BackendId> & backends)61*89c4ff92SAndroid Build Coastguard Worker void DequantizeEndToEndSimple(const std::vector<BackendId>& backends)
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape tensorShape({ 1, 2, 2, 4 });
64*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> inputData = std::vector<T>(
65*89c4ff92SAndroid Build Coastguard Worker     {
66*89c4ff92SAndroid Build Coastguard Worker         2, 4, 6, 8,
67*89c4ff92SAndroid Build Coastguard Worker         10, 12, 14, 16,
68*89c4ff92SAndroid Build Coastguard Worker         18, 20, 22, 24,
69*89c4ff92SAndroid Build Coastguard Worker         26, 28, 30, 32
70*89c4ff92SAndroid Build Coastguard Worker     });
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputData = std::vector<float>(
73*89c4ff92SAndroid Build Coastguard Worker     {
74*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f,
75*89c4ff92SAndroid Build Coastguard Worker         5.0f, 6.0f, 7.0f,  8.0f,
76*89c4ff92SAndroid Build Coastguard Worker         9.0f, 10.0f, 11.0f, 12.0f,
77*89c4ff92SAndroid Build Coastguard Worker         13.0f, 14.0f, 15.0f, 16.0f
78*89c4ff92SAndroid Build Coastguard Worker     });
79*89c4ff92SAndroid Build Coastguard Worker     DequantizeEndToEndLayerTestImpl<ArmnnType>(backends, tensorShape, inputData, expectedOutputData, 0.5f, 0);
80*89c4ff92SAndroid Build Coastguard Worker };
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
DequantizeEndToEndOffset(const std::vector<BackendId> & backends)83*89c4ff92SAndroid Build Coastguard Worker void DequantizeEndToEndOffset(const std::vector<BackendId>& backends)
84*89c4ff92SAndroid Build Coastguard Worker {
85*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape tensorShape({ 1, 2, 2, 4 });
86*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> inputData = std::vector<T>(
87*89c4ff92SAndroid Build Coastguard Worker     {
88*89c4ff92SAndroid Build Coastguard Worker         3, 5, 7, 9,
89*89c4ff92SAndroid Build Coastguard Worker         11, 13, 15, 17,
90*89c4ff92SAndroid Build Coastguard Worker         19, 21, 23, 25,
91*89c4ff92SAndroid Build Coastguard Worker         27, 29, 31, 33
92*89c4ff92SAndroid Build Coastguard Worker     });
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputData = std::vector<float>(
95*89c4ff92SAndroid Build Coastguard Worker     {
96*89c4ff92SAndroid Build Coastguard Worker         1.0f, 2.0f, 3.0f, 4.0f,
97*89c4ff92SAndroid Build Coastguard Worker         5.0f, 6.0f, 7.0f,  8.0f,
98*89c4ff92SAndroid Build Coastguard Worker         9.0f, 10.0f, 11.0f, 12.0f,
99*89c4ff92SAndroid Build Coastguard Worker         13.0f, 14.0f, 15.0f, 16.0f
100*89c4ff92SAndroid Build Coastguard Worker     });
101*89c4ff92SAndroid Build Coastguard Worker     DequantizeEndToEndLayerTestImpl<ArmnnType>(backends, tensorShape, inputData, expectedOutputData, 0.5f, 1);
102*89c4ff92SAndroid Build Coastguard Worker };
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
105