//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker
#include <reference/workloads/Decoders.hpp>

#include <fmt/format.h>

#include <doctest/doctest.h>

#include <cstddef>
#include <cstdint>
#include <vector>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("RefPerChannelDecoder")
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker template<typename T>
CompareVector(std::vector<T> vec1,std::vector<T> vec2)15*89c4ff92SAndroid Build Coastguard Worker void CompareVector(std::vector<T> vec1, std::vector<T> vec2)
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker CHECK(vec1.size() == vec2.size());
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker bool mismatch = false;
20*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 0; i < vec1.size(); ++i)
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker if (vec1[i] != vec2[i])
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker MESSAGE(fmt::format("Vector value mismatch: index={} {} != {}",
25*89c4ff92SAndroid Build Coastguard Worker i,
26*89c4ff92SAndroid Build Coastguard Worker vec1[i],
27*89c4ff92SAndroid Build Coastguard Worker vec2[i]));
28*89c4ff92SAndroid Build Coastguard Worker
29*89c4ff92SAndroid Build Coastguard Worker mismatch = true;
30*89c4ff92SAndroid Build Coastguard Worker }
31*89c4ff92SAndroid Build Coastguard Worker }
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker if (mismatch)
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker FAIL("Error in CompareVector. Vectors don't match.");
36*89c4ff92SAndroid Build Coastguard Worker }
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker
39*89c4ff92SAndroid Build Coastguard Worker // Ensure quantization works for none depthwise convolutions
40*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("RefPerChannelDecoderTest1")
41*89c4ff92SAndroid Build Coastguard Worker {
42*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
43*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> input =
44*89c4ff92SAndroid Build Coastguard Worker {
45*89c4ff92SAndroid Build Coastguard Worker 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23
46*89c4ff92SAndroid Build Coastguard Worker };
47*89c4ff92SAndroid Build Coastguard Worker
48*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expOutput =
49*89c4ff92SAndroid Build Coastguard Worker {
50*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f,
51*89c4ff92SAndroid Build Coastguard Worker 24.0f, 26.0f, 28.0f, 30.0f, 32.0f, 34.0f, 36.0f, 38.0f, 40.0f, 42.0f, 44.0f, 46.0f
52*89c4ff92SAndroid Build Coastguard Worker };
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker TensorInfo tensorInfo ({2,2,2,3},DataType::QSymmS8,{1.0f, 2.0f},0);
55*89c4ff92SAndroid Build Coastguard Worker auto decoder = MakeDecoder<float>(tensorInfo, input.data());
56*89c4ff92SAndroid Build Coastguard Worker
57*89c4ff92SAndroid Build Coastguard Worker std::vector<float> output = decoder->DecodeTensor(tensorInfo.GetShape());
58*89c4ff92SAndroid Build Coastguard Worker
59*89c4ff92SAndroid Build Coastguard Worker CompareVector(output, expOutput);
60*89c4ff92SAndroid Build Coastguard Worker }
61*89c4ff92SAndroid Build Coastguard Worker
62*89c4ff92SAndroid Build Coastguard Worker // Ensure quantization works for depthwise convolutions M=1
63*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("RefPerChannelDecoderTest2")
64*89c4ff92SAndroid Build Coastguard Worker {
65*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
66*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> input =
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
69*89c4ff92SAndroid Build Coastguard Worker };
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expOutput =
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 2.0f, 3.0f,
74*89c4ff92SAndroid Build Coastguard Worker 8.0f, 10.0f, 12.0f, 14.0f,
75*89c4ff92SAndroid Build Coastguard Worker 24.0f, 27.0f, 30.0f, 33.0f,
76*89c4ff92SAndroid Build Coastguard Worker 48.0f, 52.0f, 56.0f, 60.0f
77*89c4ff92SAndroid Build Coastguard Worker };
78*89c4ff92SAndroid Build Coastguard Worker
79*89c4ff92SAndroid Build Coastguard Worker // [O,1,H,W] = [I*M,1,H,W] = [4*1,1,2,2]
80*89c4ff92SAndroid Build Coastguard Worker TensorInfo tensorInfo ({4,1,2,2},DataType::QSymmS8,{1.0f, 2.0f, 3.0f, 4.0f},0);
81*89c4ff92SAndroid Build Coastguard Worker auto decoder = MakeDecoder<float>(tensorInfo, input.data());
82*89c4ff92SAndroid Build Coastguard Worker
83*89c4ff92SAndroid Build Coastguard Worker std::vector<float> output = decoder->DecodeTensor(tensorInfo.GetShape(), true);
84*89c4ff92SAndroid Build Coastguard Worker
85*89c4ff92SAndroid Build Coastguard Worker CompareVector(output, expOutput);
86*89c4ff92SAndroid Build Coastguard Worker }
87*89c4ff92SAndroid Build Coastguard Worker
88*89c4ff92SAndroid Build Coastguard Worker // Ensure quantization works for depthwise convolutions M=2
89*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("RefPerChannelDecoderTest3")
90*89c4ff92SAndroid Build Coastguard Worker {
91*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
92*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> input =
93*89c4ff92SAndroid Build Coastguard Worker {
94*89c4ff92SAndroid Build Coastguard Worker 0, 1, 2, 3,
95*89c4ff92SAndroid Build Coastguard Worker 4, 5, 6, 7,
96*89c4ff92SAndroid Build Coastguard Worker 8, 9, 10, 11,
97*89c4ff92SAndroid Build Coastguard Worker 12, 13, 14, 15,
98*89c4ff92SAndroid Build Coastguard Worker 16, 17, 18, 19,
99*89c4ff92SAndroid Build Coastguard Worker 20, 21, 22, 23
100*89c4ff92SAndroid Build Coastguard Worker };
101*89c4ff92SAndroid Build Coastguard Worker
102*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expOutput =
103*89c4ff92SAndroid Build Coastguard Worker {
104*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 2.0f, 3.0f,
105*89c4ff92SAndroid Build Coastguard Worker 8.0f, 10.0f, 12.0f, 14.0f,
106*89c4ff92SAndroid Build Coastguard Worker 24.0f, 27.0f, 30.0f, 33.0f,
107*89c4ff92SAndroid Build Coastguard Worker 48.0f, 52.0f, 56.0f, 60.0f,
108*89c4ff92SAndroid Build Coastguard Worker 80.0f, 85.0f, 90.0f, 95.0f,
109*89c4ff92SAndroid Build Coastguard Worker 120.0f, 126.0f, 132.0f, 138.0f
110*89c4ff92SAndroid Build Coastguard Worker };
111*89c4ff92SAndroid Build Coastguard Worker
112*89c4ff92SAndroid Build Coastguard Worker // [O,1,H,W] = [I*M,1,H,W] = [3*2,1,2,2]
113*89c4ff92SAndroid Build Coastguard Worker TensorInfo tensorInfo ({6,1,2,2},DataType::QSymmS8,{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},0);
114*89c4ff92SAndroid Build Coastguard Worker auto decoder = MakeDecoder<float>(tensorInfo, input.data());
115*89c4ff92SAndroid Build Coastguard Worker
116*89c4ff92SAndroid Build Coastguard Worker std::vector<float> output = decoder->DecodeTensor(tensorInfo.GetShape(), true);
117*89c4ff92SAndroid Build Coastguard Worker
118*89c4ff92SAndroid Build Coastguard Worker CompareVector(output, expOutput);
119*89c4ff92SAndroid Build Coastguard Worker }
120*89c4ff92SAndroid Build Coastguard Worker
121*89c4ff92SAndroid Build Coastguard Worker // Ensure quantization works for depthwise convolutions M=2 for int32
122*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("RefPerChannelDecoderTest4")
123*89c4ff92SAndroid Build Coastguard Worker {
124*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
125*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> input =
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker 0, 1, 2, 3,
128*89c4ff92SAndroid Build Coastguard Worker 4, 5, 6, 7,
129*89c4ff92SAndroid Build Coastguard Worker 8, 9, 10, 11,
130*89c4ff92SAndroid Build Coastguard Worker 12, 13, 14, 15,
131*89c4ff92SAndroid Build Coastguard Worker 16, 17, 18, 19,
132*89c4ff92SAndroid Build Coastguard Worker 20, 21, 22, 23
133*89c4ff92SAndroid Build Coastguard Worker };
134*89c4ff92SAndroid Build Coastguard Worker
135*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expOutput =
136*89c4ff92SAndroid Build Coastguard Worker {
137*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 2.0f, 3.0f,
138*89c4ff92SAndroid Build Coastguard Worker 8.0f, 10.0f, 12.0f, 14.0f,
139*89c4ff92SAndroid Build Coastguard Worker 24.0f, 27.0f, 30.0f, 33.0f,
140*89c4ff92SAndroid Build Coastguard Worker 48.0f, 52.0f, 56.0f, 60.0f,
141*89c4ff92SAndroid Build Coastguard Worker 80.0f, 85.0f, 90.0f, 95.0f,
142*89c4ff92SAndroid Build Coastguard Worker 120.0f, 126.0f, 132.0f, 138.0f
143*89c4ff92SAndroid Build Coastguard Worker };
144*89c4ff92SAndroid Build Coastguard Worker
145*89c4ff92SAndroid Build Coastguard Worker // [O,1,H,W] = [I*M,1,H,W] = [3*2,1,2,2]
146*89c4ff92SAndroid Build Coastguard Worker TensorInfo tensorInfo ({6,1,2,2},DataType::Signed32,{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},0);
147*89c4ff92SAndroid Build Coastguard Worker auto decoder = MakeDecoder<float>(tensorInfo, input.data());
148*89c4ff92SAndroid Build Coastguard Worker
149*89c4ff92SAndroid Build Coastguard Worker std::vector<float> output = decoder->DecodeTensor(tensorInfo.GetShape(), true);
150*89c4ff92SAndroid Build Coastguard Worker
151*89c4ff92SAndroid Build Coastguard Worker CompareVector(output, expOutput);
152*89c4ff92SAndroid Build Coastguard Worker }
153*89c4ff92SAndroid Build Coastguard Worker
154*89c4ff92SAndroid Build Coastguard Worker }
155