xref: /aosp_15_r20/external/armnn/src/backends/reference/test/RefPerAxisIteratorTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include <reference/workloads/Decoders.hpp>
7 
8 #include <fmt/format.h>
9 
10 #include <doctest/doctest.h>
11 
12 #include <chrono>
13 
14 template<typename T>
CompareVector(std::vector<T> vec1,std::vector<T> vec2)15 void CompareVector(std::vector<T> vec1, std::vector<T> vec2)
16 {
17     CHECK(vec1.size() == vec2.size());
18 
19     bool mismatch = false;
20     for (uint32_t i = 0; i < vec1.size(); ++i)
21     {
22         if (vec1[i] != vec2[i])
23         {
24             MESSAGE(fmt::format("Vector value mismatch: index={}  {} != {}",
25                                 i,
26                                 vec1[i],
27                                 vec2[i]));
28 
29             mismatch = true;
30         }
31     }
32 
33     if (mismatch)
34     {
35         FAIL("Error in CompareVector. Vectors don't match.");
36     }
37 }
38 
39 using namespace armnn;
40 
41 // Basically a per axis decoder but without any decoding/quantization
42 class MockPerAxisIterator : public PerAxisIterator<const int8_t, Decoder<int8_t>>
43 {
44 public:
MockPerAxisIterator(const int8_t * data,const armnn::TensorShape & tensorShape,const unsigned int axis)45     MockPerAxisIterator(const int8_t* data, const armnn::TensorShape& tensorShape, const unsigned int axis)
46             : PerAxisIterator(data, tensorShape, axis), m_NumElements(tensorShape.GetNumElements())
47     {}
48 
Get() const49     int8_t Get() const override
50     {
51         return *m_Iterator;
52     }
53 
DecodeTensor(const TensorShape & tensorShape,bool isDepthwise=false)54     virtual std::vector<float> DecodeTensor(const TensorShape &tensorShape,
55                                             bool isDepthwise = false) override
56     {
57         IgnoreUnused(tensorShape, isDepthwise);
58         return std::vector<float>{};
59     };
60 
61     // Iterates over data using operator[] and returns vector
Loop()62     std::vector<int8_t> Loop()
63     {
64         std::vector<int8_t> vec;
65         for (uint32_t i = 0; i < m_NumElements; ++i)
66         {
67             this->operator[](i);
68             vec.emplace_back(Get());
69         }
70         return vec;
71     }
72 
GetAxisIndex()73     unsigned int GetAxisIndex()
74     {
75         return m_AxisIndex;
76     }
77     unsigned int m_NumElements;
78 };
79 
80 TEST_SUITE("RefPerAxisIterator")
81 {
82 // Test Loop (Equivalent to DecodeTensor) and Axis = 0
83 TEST_CASE("PerAxisIteratorTest1")
84 {
85     std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
86     TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
87 
88     // test axis=0
89     std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
90     auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 0);
91     std::vector<int8_t> output = iterator.Loop();
92     CompareVector(output, expOutput);
93 
94     // Set iterator to index and check if the axis index is correct
95     iterator[5];
96     CHECK(iterator.GetAxisIndex() == 1u);
97 
98     iterator[1];
99     CHECK(iterator.GetAxisIndex() == 0u);
100 
101     iterator[10];
102     CHECK(iterator.GetAxisIndex() == 2u);
103 }
104 
105 // Test Axis = 1
106 TEST_CASE("PerAxisIteratorTest2")
107 {
108     std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
109     TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
110 
111     // test axis=1
112     std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
113     auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
114     std::vector<int8_t> output = iterator.Loop();
115     CompareVector(output, expOutput);
116 
117     // Set iterator to index and check if the axis index is correct
118     iterator[5];
119     CHECK(iterator.GetAxisIndex() == 0u);
120 
121     iterator[1];
122     CHECK(iterator.GetAxisIndex() == 0u);
123 
124     iterator[10];
125     CHECK(iterator.GetAxisIndex() == 0u);
126 }
127 
128 // Test Axis = 2
129 TEST_CASE("PerAxisIteratorTest3")
130 {
131     std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
132     TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
133 
134     // test axis=2
135     std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
136     auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
137     std::vector<int8_t> output = iterator.Loop();
138     CompareVector(output, expOutput);
139 
140     // Set iterator to index and check if the axis index is correct
141     iterator[5];
142     CHECK(iterator.GetAxisIndex() == 0u);
143 
144     iterator[1];
145     CHECK(iterator.GetAxisIndex() == 0u);
146 
147     iterator[10];
148     CHECK(iterator.GetAxisIndex() == 1u);
149 }
150 
151 // Test Axis = 3
152 TEST_CASE("PerAxisIteratorTest4")
153 {
154     std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
155     TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
156 
157     // test axis=3
158     std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
159     auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 3);
160     std::vector<int8_t> output = iterator.Loop();
161     CompareVector(output, expOutput);
162 
163     // Set iterator to index and check if the axis index is correct
164     iterator[5];
165     CHECK(iterator.GetAxisIndex() == 1u);
166 
167     iterator[1];
168     CHECK(iterator.GetAxisIndex() == 1u);
169 
170     iterator[10];
171     CHECK(iterator.GetAxisIndex() == 0u);
172 }
173 
174 // Test Axis = 1. Different tensor shape
175 TEST_CASE("PerAxisIteratorTest5")
176 {
177     using namespace armnn;
178     std::vector<int8_t> input =
179     {
180          0,  1,  2,  3,
181          4,  5,  6,  7,
182          8,  9, 10, 11,
183         12, 13, 14, 15
184     };
185 
186     std::vector<int8_t> expOutput =
187     {
188          0,  1,  2,  3,
189          4,  5,  6,  7,
190          8,  9, 10, 11,
191         12, 13, 14, 15
192     };
193 
194     TensorInfo tensorInfo ({2,2,2,2},DataType::QSymmS8);
195     auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
196     std::vector<int8_t> output = iterator.Loop();
197     CompareVector(output, expOutput);
198 
199     // Set iterator to index and check if the axis index is correct
200     iterator[5];
201     CHECK(iterator.GetAxisIndex() == 1u);
202 
203     iterator[1];
204     CHECK(iterator.GetAxisIndex() == 0u);
205 
206     iterator[10];
207     CHECK(iterator.GetAxisIndex() == 0u);
208 }
209 
210 // Test the increment and decrement operator
211 TEST_CASE("PerAxisIteratorTest7")
212 {
213     using namespace armnn;
214     std::vector<int8_t> input =
215     {
216         0, 1,  2,  3,
217         4, 5,  6,  7,
218         8, 9, 10, 11
219     };
220 
221     std::vector<int8_t> expOutput =
222     {
223         0, 1,  2,  3,
224         4, 5,  6,  7,
225         8, 9, 10, 11
226     };
227 
228     TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
229     auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
230 
231     iterator += 3;
232     CHECK(iterator.Get() == expOutput[3]);
233     CHECK(iterator.GetAxisIndex() == 1u);
234 
235     iterator += 3;
236     CHECK(iterator.Get() == expOutput[6]);
237     CHECK(iterator.GetAxisIndex() == 1u);
238 
239     iterator -= 2;
240     CHECK(iterator.Get() == expOutput[4]);
241     CHECK(iterator.GetAxisIndex() == 0u);
242 
243     iterator -= 1;
244     CHECK(iterator.Get() == expOutput[3]);
245     CHECK(iterator.GetAxisIndex() == 1u);
246 }
247 
248 }