1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <reference/workloads/Decoders.hpp>
7
8 #include <fmt/format.h>
9
10 #include <doctest/doctest.h>
11
12 #include <chrono>
13
14 template<typename T>
CompareVector(std::vector<T> vec1,std::vector<T> vec2)15 void CompareVector(std::vector<T> vec1, std::vector<T> vec2)
16 {
17 CHECK(vec1.size() == vec2.size());
18
19 bool mismatch = false;
20 for (uint32_t i = 0; i < vec1.size(); ++i)
21 {
22 if (vec1[i] != vec2[i])
23 {
24 MESSAGE(fmt::format("Vector value mismatch: index={} {} != {}",
25 i,
26 vec1[i],
27 vec2[i]));
28
29 mismatch = true;
30 }
31 }
32
33 if (mismatch)
34 {
35 FAIL("Error in CompareVector. Vectors don't match.");
36 }
37 }
38
39 using namespace armnn;
40
41 // Basically a per axis decoder but without any decoding/quantization
42 class MockPerAxisIterator : public PerAxisIterator<const int8_t, Decoder<int8_t>>
43 {
44 public:
MockPerAxisIterator(const int8_t * data,const armnn::TensorShape & tensorShape,const unsigned int axis)45 MockPerAxisIterator(const int8_t* data, const armnn::TensorShape& tensorShape, const unsigned int axis)
46 : PerAxisIterator(data, tensorShape, axis), m_NumElements(tensorShape.GetNumElements())
47 {}
48
Get() const49 int8_t Get() const override
50 {
51 return *m_Iterator;
52 }
53
DecodeTensor(const TensorShape & tensorShape,bool isDepthwise=false)54 virtual std::vector<float> DecodeTensor(const TensorShape &tensorShape,
55 bool isDepthwise = false) override
56 {
57 IgnoreUnused(tensorShape, isDepthwise);
58 return std::vector<float>{};
59 };
60
61 // Iterates over data using operator[] and returns vector
Loop()62 std::vector<int8_t> Loop()
63 {
64 std::vector<int8_t> vec;
65 for (uint32_t i = 0; i < m_NumElements; ++i)
66 {
67 this->operator[](i);
68 vec.emplace_back(Get());
69 }
70 return vec;
71 }
72
GetAxisIndex()73 unsigned int GetAxisIndex()
74 {
75 return m_AxisIndex;
76 }
77 unsigned int m_NumElements;
78 };
79
80 TEST_SUITE("RefPerAxisIterator")
81 {
82 // Test Loop (Equivalent to DecodeTensor) and Axis = 0
83 TEST_CASE("PerAxisIteratorTest1")
84 {
85 std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
86 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
87
88 // test axis=0
89 std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
90 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 0);
91 std::vector<int8_t> output = iterator.Loop();
92 CompareVector(output, expOutput);
93
94 // Set iterator to index and check if the axis index is correct
95 iterator[5];
96 CHECK(iterator.GetAxisIndex() == 1u);
97
98 iterator[1];
99 CHECK(iterator.GetAxisIndex() == 0u);
100
101 iterator[10];
102 CHECK(iterator.GetAxisIndex() == 2u);
103 }
104
105 // Test Axis = 1
106 TEST_CASE("PerAxisIteratorTest2")
107 {
108 std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
109 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
110
111 // test axis=1
112 std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
113 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
114 std::vector<int8_t> output = iterator.Loop();
115 CompareVector(output, expOutput);
116
117 // Set iterator to index and check if the axis index is correct
118 iterator[5];
119 CHECK(iterator.GetAxisIndex() == 0u);
120
121 iterator[1];
122 CHECK(iterator.GetAxisIndex() == 0u);
123
124 iterator[10];
125 CHECK(iterator.GetAxisIndex() == 0u);
126 }
127
128 // Test Axis = 2
129 TEST_CASE("PerAxisIteratorTest3")
130 {
131 std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
132 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
133
134 // test axis=2
135 std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
136 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
137 std::vector<int8_t> output = iterator.Loop();
138 CompareVector(output, expOutput);
139
140 // Set iterator to index and check if the axis index is correct
141 iterator[5];
142 CHECK(iterator.GetAxisIndex() == 0u);
143
144 iterator[1];
145 CHECK(iterator.GetAxisIndex() == 0u);
146
147 iterator[10];
148 CHECK(iterator.GetAxisIndex() == 1u);
149 }
150
151 // Test Axis = 3
152 TEST_CASE("PerAxisIteratorTest4")
153 {
154 std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
155 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
156
157 // test axis=3
158 std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
159 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 3);
160 std::vector<int8_t> output = iterator.Loop();
161 CompareVector(output, expOutput);
162
163 // Set iterator to index and check if the axis index is correct
164 iterator[5];
165 CHECK(iterator.GetAxisIndex() == 1u);
166
167 iterator[1];
168 CHECK(iterator.GetAxisIndex() == 1u);
169
170 iterator[10];
171 CHECK(iterator.GetAxisIndex() == 0u);
172 }
173
174 // Test Axis = 1. Different tensor shape
175 TEST_CASE("PerAxisIteratorTest5")
176 {
177 using namespace armnn;
178 std::vector<int8_t> input =
179 {
180 0, 1, 2, 3,
181 4, 5, 6, 7,
182 8, 9, 10, 11,
183 12, 13, 14, 15
184 };
185
186 std::vector<int8_t> expOutput =
187 {
188 0, 1, 2, 3,
189 4, 5, 6, 7,
190 8, 9, 10, 11,
191 12, 13, 14, 15
192 };
193
194 TensorInfo tensorInfo ({2,2,2,2},DataType::QSymmS8);
195 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
196 std::vector<int8_t> output = iterator.Loop();
197 CompareVector(output, expOutput);
198
199 // Set iterator to index and check if the axis index is correct
200 iterator[5];
201 CHECK(iterator.GetAxisIndex() == 1u);
202
203 iterator[1];
204 CHECK(iterator.GetAxisIndex() == 0u);
205
206 iterator[10];
207 CHECK(iterator.GetAxisIndex() == 0u);
208 }
209
210 // Test the increment and decrement operator
211 TEST_CASE("PerAxisIteratorTest7")
212 {
213 using namespace armnn;
214 std::vector<int8_t> input =
215 {
216 0, 1, 2, 3,
217 4, 5, 6, 7,
218 8, 9, 10, 11
219 };
220
221 std::vector<int8_t> expOutput =
222 {
223 0, 1, 2, 3,
224 4, 5, 6, 7,
225 8, 9, 10, 11
226 };
227
228 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
229 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
230
231 iterator += 3;
232 CHECK(iterator.Get() == expOutput[3]);
233 CHECK(iterator.GetAxisIndex() == 1u);
234
235 iterator += 3;
236 CHECK(iterator.Get() == expOutput[6]);
237 CHECK(iterator.GetAxisIndex() == 1u);
238
239 iterator -= 2;
240 CHECK(iterator.Get() == expOutput[4]);
241 CHECK(iterator.GetAxisIndex() == 0u);
242
243 iterator -= 1;
244 CHECK(iterator.Get() == expOutput[3]);
245 CHECK(iterator.GetAxisIndex() == 1u);
246 }
247
248 }