xref: /aosp_15_r20/external/armnn/src/backends/reference/test/RefPerChannelDecoderTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <reference/workloads/Decoders.hpp>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("RefPerChannelDecoder")
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker template<typename T>
CompareVector(std::vector<T> vec1,std::vector<T> vec2)15*89c4ff92SAndroid Build Coastguard Worker void CompareVector(std::vector<T> vec1, std::vector<T> vec2)
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker     CHECK(vec1.size() == vec2.size());
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker     bool mismatch = false;
20*89c4ff92SAndroid Build Coastguard Worker     for (uint32_t i = 0; i < vec1.size(); ++i)
21*89c4ff92SAndroid Build Coastguard Worker     {
22*89c4ff92SAndroid Build Coastguard Worker         if (vec1[i] != vec2[i])
23*89c4ff92SAndroid Build Coastguard Worker         {
24*89c4ff92SAndroid Build Coastguard Worker             MESSAGE(fmt::format("Vector value mismatch: index={}  {} != {}",
25*89c4ff92SAndroid Build Coastguard Worker                                 i,
26*89c4ff92SAndroid Build Coastguard Worker                                 vec1[i],
27*89c4ff92SAndroid Build Coastguard Worker                                 vec2[i]));
28*89c4ff92SAndroid Build Coastguard Worker 
29*89c4ff92SAndroid Build Coastguard Worker             mismatch = true;
30*89c4ff92SAndroid Build Coastguard Worker         }
31*89c4ff92SAndroid Build Coastguard Worker     }
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     if (mismatch)
34*89c4ff92SAndroid Build Coastguard Worker     {
35*89c4ff92SAndroid Build Coastguard Worker         FAIL("Error in CompareVector. Vectors don't match.");
36*89c4ff92SAndroid Build Coastguard Worker     }
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker // Ensure quantization works for none depthwise convolutions
40*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("RefPerChannelDecoderTest1")
41*89c4ff92SAndroid Build Coastguard Worker {
42*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
43*89c4ff92SAndroid Build Coastguard Worker     std::vector<int8_t> input =
44*89c4ff92SAndroid Build Coastguard Worker     {
45*89c4ff92SAndroid Build Coastguard Worker         0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23
46*89c4ff92SAndroid Build Coastguard Worker     };
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expOutput =
49*89c4ff92SAndroid Build Coastguard Worker     {
50*89c4ff92SAndroid Build Coastguard Worker         0.0f,   1.0f,  2.0f,  3.0f,  4.0f,  5.0f,  6.0f,  7.0f,  8.0f,  9.0f, 10.0f, 11.0f,
51*89c4ff92SAndroid Build Coastguard Worker         24.0f, 26.0f, 28.0f, 30.0f, 32.0f, 34.0f, 36.0f, 38.0f, 40.0f, 42.0f, 44.0f, 46.0f
52*89c4ff92SAndroid Build Coastguard Worker     };
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     TensorInfo tensorInfo ({2,2,2,3},DataType::QSymmS8,{1.0f, 2.0f},0);
55*89c4ff92SAndroid Build Coastguard Worker     auto decoder = MakeDecoder<float>(tensorInfo, input.data());
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> output = decoder->DecodeTensor(tensorInfo.GetShape());
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     CompareVector(output, expOutput);
60*89c4ff92SAndroid Build Coastguard Worker }
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker // Ensure quantization works for depthwise convolutions M=1
63*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("RefPerChannelDecoderTest2")
64*89c4ff92SAndroid Build Coastguard Worker {
65*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
66*89c4ff92SAndroid Build Coastguard Worker     std::vector<int8_t> input =
67*89c4ff92SAndroid Build Coastguard Worker     {
68*89c4ff92SAndroid Build Coastguard Worker         0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
69*89c4ff92SAndroid Build Coastguard Worker     };
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expOutput =
72*89c4ff92SAndroid Build Coastguard Worker     {
73*89c4ff92SAndroid Build Coastguard Worker          0.0f,  1.0f,  2.0f,  3.0f,
74*89c4ff92SAndroid Build Coastguard Worker          8.0f, 10.0f, 12.0f, 14.0f,
75*89c4ff92SAndroid Build Coastguard Worker         24.0f, 27.0f, 30.0f, 33.0f,
76*89c4ff92SAndroid Build Coastguard Worker         48.0f, 52.0f, 56.0f, 60.0f
77*89c4ff92SAndroid Build Coastguard Worker     };
78*89c4ff92SAndroid Build Coastguard Worker 
79*89c4ff92SAndroid Build Coastguard Worker     // [O,1,H,W] = [I*M,1,H,W] = [4*1,1,2,2]
80*89c4ff92SAndroid Build Coastguard Worker     TensorInfo tensorInfo ({4,1,2,2},DataType::QSymmS8,{1.0f, 2.0f, 3.0f, 4.0f},0);
81*89c4ff92SAndroid Build Coastguard Worker     auto decoder = MakeDecoder<float>(tensorInfo, input.data());
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> output = decoder->DecodeTensor(tensorInfo.GetShape(), true);
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker     CompareVector(output, expOutput);
86*89c4ff92SAndroid Build Coastguard Worker }
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker // Ensure quantization works for depthwise convolutions M=2
89*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("RefPerChannelDecoderTest3")
90*89c4ff92SAndroid Build Coastguard Worker {
91*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
92*89c4ff92SAndroid Build Coastguard Worker     std::vector<int8_t> input =
93*89c4ff92SAndroid Build Coastguard Worker     {
94*89c4ff92SAndroid Build Coastguard Worker         0, 1, 2, 3,
95*89c4ff92SAndroid Build Coastguard Worker         4, 5, 6, 7,
96*89c4ff92SAndroid Build Coastguard Worker         8, 9, 10, 11,
97*89c4ff92SAndroid Build Coastguard Worker         12, 13, 14, 15,
98*89c4ff92SAndroid Build Coastguard Worker         16, 17, 18, 19,
99*89c4ff92SAndroid Build Coastguard Worker         20, 21, 22, 23
100*89c4ff92SAndroid Build Coastguard Worker     };
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expOutput =
103*89c4ff92SAndroid Build Coastguard Worker     {
104*89c4ff92SAndroid Build Coastguard Worker          0.0f,  1.0f,  2.0f,  3.0f,
105*89c4ff92SAndroid Build Coastguard Worker          8.0f, 10.0f, 12.0f, 14.0f,
106*89c4ff92SAndroid Build Coastguard Worker         24.0f, 27.0f, 30.0f, 33.0f,
107*89c4ff92SAndroid Build Coastguard Worker         48.0f, 52.0f, 56.0f, 60.0f,
108*89c4ff92SAndroid Build Coastguard Worker         80.0f, 85.0f, 90.0f, 95.0f,
109*89c4ff92SAndroid Build Coastguard Worker         120.0f, 126.0f, 132.0f, 138.0f
110*89c4ff92SAndroid Build Coastguard Worker     };
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker     // [O,1,H,W] = [I*M,1,H,W] = [3*2,1,2,2]
113*89c4ff92SAndroid Build Coastguard Worker     TensorInfo tensorInfo ({6,1,2,2},DataType::QSymmS8,{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},0);
114*89c4ff92SAndroid Build Coastguard Worker     auto decoder = MakeDecoder<float>(tensorInfo, input.data());
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> output = decoder->DecodeTensor(tensorInfo.GetShape(), true);
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker     CompareVector(output, expOutput);
119*89c4ff92SAndroid Build Coastguard Worker }
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker // Ensure quantization works for depthwise convolutions M=2 for int32
122*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("RefPerChannelDecoderTest4")
123*89c4ff92SAndroid Build Coastguard Worker {
124*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
125*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> input =
126*89c4ff92SAndroid Build Coastguard Worker     {
127*89c4ff92SAndroid Build Coastguard Worker         0, 1, 2, 3,
128*89c4ff92SAndroid Build Coastguard Worker         4, 5, 6, 7,
129*89c4ff92SAndroid Build Coastguard Worker         8, 9, 10, 11,
130*89c4ff92SAndroid Build Coastguard Worker         12, 13, 14, 15,
131*89c4ff92SAndroid Build Coastguard Worker         16, 17, 18, 19,
132*89c4ff92SAndroid Build Coastguard Worker         20, 21, 22, 23
133*89c4ff92SAndroid Build Coastguard Worker     };
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expOutput =
136*89c4ff92SAndroid Build Coastguard Worker     {
137*89c4ff92SAndroid Build Coastguard Worker          0.0f,  1.0f,  2.0f,  3.0f,
138*89c4ff92SAndroid Build Coastguard Worker          8.0f, 10.0f, 12.0f, 14.0f,
139*89c4ff92SAndroid Build Coastguard Worker         24.0f, 27.0f, 30.0f, 33.0f,
140*89c4ff92SAndroid Build Coastguard Worker         48.0f, 52.0f, 56.0f, 60.0f,
141*89c4ff92SAndroid Build Coastguard Worker         80.0f, 85.0f, 90.0f, 95.0f,
142*89c4ff92SAndroid Build Coastguard Worker         120.0f, 126.0f, 132.0f, 138.0f
143*89c4ff92SAndroid Build Coastguard Worker     };
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker     // [O,1,H,W] = [I*M,1,H,W] = [3*2,1,2,2]
146*89c4ff92SAndroid Build Coastguard Worker     TensorInfo tensorInfo ({6,1,2,2},DataType::Signed32,{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},0);
147*89c4ff92SAndroid Build Coastguard Worker     auto decoder = MakeDecoder<float>(tensorInfo, input.data());
148*89c4ff92SAndroid Build Coastguard Worker 
149*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> output = decoder->DecodeTensor(tensorInfo.GetShape(), true);
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker     CompareVector(output, expOutput);
152*89c4ff92SAndroid Build Coastguard Worker }
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker }
155