xref: /aosp_15_r20/external/XNNPACK/test/transpose-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <vector>

#include <gtest/gtest.h>

#include <xnnpack.h>
#include <xnnpack/microfnptr.h>
19*4bdc9457SAndroid Build Coastguard Worker 
20*4bdc9457SAndroid Build Coastguard Worker 
21*4bdc9457SAndroid Build Coastguard Worker class TransposeMicrokernelTester {
22*4bdc9457SAndroid Build Coastguard Worker  public:
element_size(size_t element_size)23*4bdc9457SAndroid Build Coastguard Worker   inline TransposeMicrokernelTester& element_size(size_t element_size) {
24*4bdc9457SAndroid Build Coastguard Worker     assert(element_size != 0);
25*4bdc9457SAndroid Build Coastguard Worker     this->element_size_ = element_size;
26*4bdc9457SAndroid Build Coastguard Worker     return *this;
27*4bdc9457SAndroid Build Coastguard Worker   }
28*4bdc9457SAndroid Build Coastguard Worker 
element_size()29*4bdc9457SAndroid Build Coastguard Worker   inline size_t element_size() const { return this->element_size_; }
30*4bdc9457SAndroid Build Coastguard Worker 
block_height(size_t block_height)31*4bdc9457SAndroid Build Coastguard Worker   inline TransposeMicrokernelTester& block_height(size_t block_height) {
32*4bdc9457SAndroid Build Coastguard Worker     assert(block_height != 0);
33*4bdc9457SAndroid Build Coastguard Worker     this->block_height_ = block_height;
34*4bdc9457SAndroid Build Coastguard Worker     return *this;
35*4bdc9457SAndroid Build Coastguard Worker   }
36*4bdc9457SAndroid Build Coastguard Worker 
block_height()37*4bdc9457SAndroid Build Coastguard Worker   inline size_t block_height() const { return this->block_height_; }
38*4bdc9457SAndroid Build Coastguard Worker 
block_width(size_t block_width)39*4bdc9457SAndroid Build Coastguard Worker   inline TransposeMicrokernelTester& block_width(size_t block_width) {
40*4bdc9457SAndroid Build Coastguard Worker     assert(block_width != 0);
41*4bdc9457SAndroid Build Coastguard Worker     this->block_width_ = block_width;
42*4bdc9457SAndroid Build Coastguard Worker     return *this;
43*4bdc9457SAndroid Build Coastguard Worker   }
44*4bdc9457SAndroid Build Coastguard Worker 
block_width()45*4bdc9457SAndroid Build Coastguard Worker   inline size_t block_width() const { return this->block_width_; }
46*4bdc9457SAndroid Build Coastguard Worker 
input_stride(size_t input_stride)47*4bdc9457SAndroid Build Coastguard Worker   inline TransposeMicrokernelTester& input_stride(size_t input_stride) {
48*4bdc9457SAndroid Build Coastguard Worker     this->input_stride_ = input_stride;
49*4bdc9457SAndroid Build Coastguard Worker     return *this;
50*4bdc9457SAndroid Build Coastguard Worker   }
51*4bdc9457SAndroid Build Coastguard Worker 
input_stride()52*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_stride() const { return this->input_stride_; }
53*4bdc9457SAndroid Build Coastguard Worker 
output_stride(size_t output_stride)54*4bdc9457SAndroid Build Coastguard Worker   inline TransposeMicrokernelTester& output_stride(size_t output_stride) {
55*4bdc9457SAndroid Build Coastguard Worker     this->output_stride_ = output_stride;
56*4bdc9457SAndroid Build Coastguard Worker     return *this;
57*4bdc9457SAndroid Build Coastguard Worker   }
58*4bdc9457SAndroid Build Coastguard Worker 
output_stride()59*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_stride() const { return this->output_stride_; }
60*4bdc9457SAndroid Build Coastguard Worker 
input_element_stride(size_t input_element_stride)61*4bdc9457SAndroid Build Coastguard Worker   inline TransposeMicrokernelTester& input_element_stride(size_t input_element_stride) {
62*4bdc9457SAndroid Build Coastguard Worker     assert(input_element_stride >=  element_size_);
63*4bdc9457SAndroid Build Coastguard Worker     this->input_element_stride_ = input_element_stride;
64*4bdc9457SAndroid Build Coastguard Worker     return *this;
65*4bdc9457SAndroid Build Coastguard Worker   }
66*4bdc9457SAndroid Build Coastguard Worker 
input_element_stride()67*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_element_stride() const {
68*4bdc9457SAndroid Build Coastguard Worker     if (input_element_stride_ == 0) {
69*4bdc9457SAndroid Build Coastguard Worker       return element_size_;
70*4bdc9457SAndroid Build Coastguard Worker     } else {
71*4bdc9457SAndroid Build Coastguard Worker       return input_element_stride_;
72*4bdc9457SAndroid Build Coastguard Worker     }
73*4bdc9457SAndroid Build Coastguard Worker   }
74*4bdc9457SAndroid Build Coastguard Worker 
output_element_stride(size_t output_element_stride)75*4bdc9457SAndroid Build Coastguard Worker   inline TransposeMicrokernelTester& output_element_stride(size_t output_element_stride) {
76*4bdc9457SAndroid Build Coastguard Worker     assert(output_element_stride >=  element_size_);
77*4bdc9457SAndroid Build Coastguard Worker     this->output_element_stride_ = output_element_stride;
78*4bdc9457SAndroid Build Coastguard Worker     return *this;
79*4bdc9457SAndroid Build Coastguard Worker   }
80*4bdc9457SAndroid Build Coastguard Worker 
output_element_stride()81*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_element_stride() const {
82*4bdc9457SAndroid Build Coastguard Worker     if (output_element_stride_ == 0) {
83*4bdc9457SAndroid Build Coastguard Worker       return element_size_;
84*4bdc9457SAndroid Build Coastguard Worker     } else {
85*4bdc9457SAndroid Build Coastguard Worker       return output_element_stride_;
86*4bdc9457SAndroid Build Coastguard Worker     }
87*4bdc9457SAndroid Build Coastguard Worker   }
88*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)89*4bdc9457SAndroid Build Coastguard Worker   inline TransposeMicrokernelTester& iterations(size_t iterations) {
90*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
91*4bdc9457SAndroid Build Coastguard Worker     return *this;
92*4bdc9457SAndroid Build Coastguard Worker   }
93*4bdc9457SAndroid Build Coastguard Worker 
iterations()94*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const { return this->iterations_; }
95*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_transposev_ukernel_function transpose)96*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_transposev_ukernel_function transpose) const {
97*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(input_stride() * block_height() * input_element_stride() + XNN_EXTRA_BYTES);
98*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(output_stride() * block_width() * output_element_stride());
99*4bdc9457SAndroid Build Coastguard Worker     std::iota(input.begin(), input.end(), 0);
100*4bdc9457SAndroid Build Coastguard Worker     std::fill(output.begin(), output.end(), UINT8_C(0xA5));
101*4bdc9457SAndroid Build Coastguard Worker 
102*4bdc9457SAndroid Build Coastguard Worker     // Call optimized micro-kernel.
103*4bdc9457SAndroid Build Coastguard Worker     transpose(input.data(),
104*4bdc9457SAndroid Build Coastguard Worker               output.data(),
105*4bdc9457SAndroid Build Coastguard Worker               input_stride() * input_element_stride(),
106*4bdc9457SAndroid Build Coastguard Worker               output_stride() * output_element_stride(),
107*4bdc9457SAndroid Build Coastguard Worker               input_element_stride(),
108*4bdc9457SAndroid Build Coastguard Worker               output_element_stride(),
109*4bdc9457SAndroid Build Coastguard Worker               element_size(),
110*4bdc9457SAndroid Build Coastguard Worker               block_width(),
111*4bdc9457SAndroid Build Coastguard Worker               block_height());
112*4bdc9457SAndroid Build Coastguard Worker 
113*4bdc9457SAndroid Build Coastguard Worker     // Verify results.
114*4bdc9457SAndroid Build Coastguard Worker     for (size_t c = 0; c < block_width(); c++) {
115*4bdc9457SAndroid Build Coastguard Worker       for (size_t r = 0; r < block_height(); r++) {
116*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(std::memcmp(&input[input_element_stride() * (c+ r * input_stride())],
117*4bdc9457SAndroid Build Coastguard Worker                               &output[output_element_stride() * (r + c * output_stride())],
118*4bdc9457SAndroid Build Coastguard Worker                               element_size()), 0)
119*4bdc9457SAndroid Build Coastguard Worker             << "at row " << r << " / " << block_height()
120*4bdc9457SAndroid Build Coastguard Worker             << ", at column " << c << " / " << block_width();
121*4bdc9457SAndroid Build Coastguard Worker       }
122*4bdc9457SAndroid Build Coastguard Worker     }
123*4bdc9457SAndroid Build Coastguard Worker   }
124*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_x64_transposec_ukernel_function transpose)125*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_x64_transposec_ukernel_function transpose) const {
126*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint64_t> input(input_stride() * output_stride() + XNN_EXTRA_BYTES / sizeof(uint64_t));
127*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint64_t> output(input_stride() * output_stride());
128*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
129*4bdc9457SAndroid Build Coastguard Worker       std::iota(input.begin(), input.end(), 0);
130*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT64_C(0xBADC0FFEE0DDF00D));
131*4bdc9457SAndroid Build Coastguard Worker 
132*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
133*4bdc9457SAndroid Build Coastguard Worker       transpose(input.data(),
134*4bdc9457SAndroid Build Coastguard Worker                 output.data(),
135*4bdc9457SAndroid Build Coastguard Worker                 input_stride() * sizeof(uint64_t),
136*4bdc9457SAndroid Build Coastguard Worker                 output_stride() * sizeof(uint64_t),
137*4bdc9457SAndroid Build Coastguard Worker                 block_width(),
138*4bdc9457SAndroid Build Coastguard Worker                 block_height());
139*4bdc9457SAndroid Build Coastguard Worker 
140*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
141*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < block_width(); c++) {
142*4bdc9457SAndroid Build Coastguard Worker         for (size_t r = 0; r < block_height(); r++) {
143*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(input[c + r * input_stride()], output[r + c * output_stride()])
144*4bdc9457SAndroid Build Coastguard Worker               << "at row " << r << " / " << block_height()
145*4bdc9457SAndroid Build Coastguard Worker               << ", at column " << c << " / " << block_width();
146*4bdc9457SAndroid Build Coastguard Worker         }
147*4bdc9457SAndroid Build Coastguard Worker       }
148*4bdc9457SAndroid Build Coastguard Worker     }
149*4bdc9457SAndroid Build Coastguard Worker   }
150*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_x32_transposec_ukernel_function transpose)151*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_x32_transposec_ukernel_function transpose) const {
152*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> input(input_stride() * output_stride() + XNN_EXTRA_BYTES / sizeof(uint32_t));
153*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> output(input_stride() * output_stride());
154*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
155*4bdc9457SAndroid Build Coastguard Worker       std::iota(input.begin(), input.end(), 0);
156*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEEF));
157*4bdc9457SAndroid Build Coastguard Worker 
158*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
159*4bdc9457SAndroid Build Coastguard Worker       transpose(input.data(),
160*4bdc9457SAndroid Build Coastguard Worker                 output.data(),
161*4bdc9457SAndroid Build Coastguard Worker                 input_stride() * sizeof(uint32_t),
162*4bdc9457SAndroid Build Coastguard Worker                 output_stride() * sizeof(uint32_t),
163*4bdc9457SAndroid Build Coastguard Worker                 block_width(),
164*4bdc9457SAndroid Build Coastguard Worker                 block_height());
165*4bdc9457SAndroid Build Coastguard Worker 
166*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
167*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < block_width(); c++) {
168*4bdc9457SAndroid Build Coastguard Worker         for (size_t r = 0; r < block_height(); r++) {
169*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(input[c + r * input_stride()], output[r + c * output_stride()])
170*4bdc9457SAndroid Build Coastguard Worker               << "at row " << r << " / " << block_height()
171*4bdc9457SAndroid Build Coastguard Worker               << ", at column " << c << " / " << block_width();
172*4bdc9457SAndroid Build Coastguard Worker         }
173*4bdc9457SAndroid Build Coastguard Worker       }
174*4bdc9457SAndroid Build Coastguard Worker     }
175*4bdc9457SAndroid Build Coastguard Worker   }
176*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_x24_transposec_ukernel_function transpose)177*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_x24_transposec_ukernel_function transpose) const {
178*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(input_stride() * output_stride() * element_size() + XNN_EXTRA_BYTES);
179*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(input_stride() * output_stride() * element_size());
180*4bdc9457SAndroid Build Coastguard Worker     std::iota(input.begin(), input.end(), 0);
181*4bdc9457SAndroid Build Coastguard Worker     std::fill(output.begin(), output.end(), UINT8_C(0xA5));
182*4bdc9457SAndroid Build Coastguard Worker 
183*4bdc9457SAndroid Build Coastguard Worker     // Call optimized micro-kernel.
184*4bdc9457SAndroid Build Coastguard Worker     transpose(input.data(),
185*4bdc9457SAndroid Build Coastguard Worker               output.data(),
186*4bdc9457SAndroid Build Coastguard Worker               input_stride() * element_size(),
187*4bdc9457SAndroid Build Coastguard Worker               output_stride() * element_size(),
188*4bdc9457SAndroid Build Coastguard Worker               block_width(),
189*4bdc9457SAndroid Build Coastguard Worker               block_height());
190*4bdc9457SAndroid Build Coastguard Worker 
191*4bdc9457SAndroid Build Coastguard Worker     // Verify results.
192*4bdc9457SAndroid Build Coastguard Worker     for (size_t c = 0; c < block_width(); c++) {
193*4bdc9457SAndroid Build Coastguard Worker       for (size_t r = 0; r < block_height(); r++) {
194*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(std::memcmp(&input[element_size() * (c+ r * input_stride())],
195*4bdc9457SAndroid Build Coastguard Worker                               &output[element_size() * (r + c * output_stride())],
196*4bdc9457SAndroid Build Coastguard Worker                               element_size()), 0)
197*4bdc9457SAndroid Build Coastguard Worker             << "at row " << r << " / " << block_height()
198*4bdc9457SAndroid Build Coastguard Worker             << ", at column " << c << " / " << block_width();
199*4bdc9457SAndroid Build Coastguard Worker       }
200*4bdc9457SAndroid Build Coastguard Worker     }
201*4bdc9457SAndroid Build Coastguard Worker   }
202*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_x16_transposec_ukernel_function transpose)203*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_x16_transposec_ukernel_function transpose) const {
204*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(input_stride() * output_stride() + XNN_EXTRA_BYTES / sizeof(uint16_t));
205*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output(input_stride() * output_stride());
206*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
207*4bdc9457SAndroid Build Coastguard Worker       std::iota(input.begin(), input.end(), 0);
208*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0xDEAD));
209*4bdc9457SAndroid Build Coastguard Worker 
210*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
211*4bdc9457SAndroid Build Coastguard Worker       transpose(input.data(),
212*4bdc9457SAndroid Build Coastguard Worker                 output.data(),
213*4bdc9457SAndroid Build Coastguard Worker                 input_stride() * sizeof(uint16_t),
214*4bdc9457SAndroid Build Coastguard Worker                 output_stride() * sizeof(uint16_t),
215*4bdc9457SAndroid Build Coastguard Worker                 block_width(),
216*4bdc9457SAndroid Build Coastguard Worker                 block_height());
217*4bdc9457SAndroid Build Coastguard Worker 
218*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
219*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < block_width(); c++) {
220*4bdc9457SAndroid Build Coastguard Worker         for (size_t r = 0; r < block_height(); r++) {
221*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(input[c + r * input_stride()], output[r + c * output_stride()])
222*4bdc9457SAndroid Build Coastguard Worker               << "at row " << r << " / " << block_height()
223*4bdc9457SAndroid Build Coastguard Worker               << ", at column " << c << " / " << block_width();
224*4bdc9457SAndroid Build Coastguard Worker         }
225*4bdc9457SAndroid Build Coastguard Worker       }
226*4bdc9457SAndroid Build Coastguard Worker     }
227*4bdc9457SAndroid Build Coastguard Worker   }
228*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_x8_transposec_ukernel_function transpose)229*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_x8_transposec_ukernel_function transpose) const {
230*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(input_stride() * output_stride() + XNN_EXTRA_BYTES);
231*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(input_stride() * output_stride());
232*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
233*4bdc9457SAndroid Build Coastguard Worker       std::iota(input.begin(), input.end(), 0);
234*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
235*4bdc9457SAndroid Build Coastguard Worker 
236*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
237*4bdc9457SAndroid Build Coastguard Worker       transpose(input.data(),
238*4bdc9457SAndroid Build Coastguard Worker                 output.data(),
239*4bdc9457SAndroid Build Coastguard Worker                 input_stride() * sizeof(uint8_t),
240*4bdc9457SAndroid Build Coastguard Worker                 output_stride() * sizeof(uint8_t),
241*4bdc9457SAndroid Build Coastguard Worker                 block_width(),
242*4bdc9457SAndroid Build Coastguard Worker                 block_height());
243*4bdc9457SAndroid Build Coastguard Worker 
244*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
245*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < block_width(); c++) {
246*4bdc9457SAndroid Build Coastguard Worker         for (size_t r = 0; r < block_height(); r++) {
247*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ((int)input[c + r * input_stride()], (int)output[r + c * output_stride()])
248*4bdc9457SAndroid Build Coastguard Worker               << "at row " << r << " / " << block_height()
249*4bdc9457SAndroid Build Coastguard Worker               << ", at column " << c << " / " << block_width();
250*4bdc9457SAndroid Build Coastguard Worker         }
251*4bdc9457SAndroid Build Coastguard Worker       }
252*4bdc9457SAndroid Build Coastguard Worker     }
253*4bdc9457SAndroid Build Coastguard Worker   }
254*4bdc9457SAndroid Build Coastguard Worker 
255*4bdc9457SAndroid Build Coastguard Worker  private:
256*4bdc9457SAndroid Build Coastguard Worker   size_t element_size_ = 1;
257*4bdc9457SAndroid Build Coastguard Worker   size_t input_stride_ = 1;
258*4bdc9457SAndroid Build Coastguard Worker   size_t output_stride_ = 1;
259*4bdc9457SAndroid Build Coastguard Worker   size_t input_element_stride_ = 0;
260*4bdc9457SAndroid Build Coastguard Worker   size_t output_element_stride_ = 0;
261*4bdc9457SAndroid Build Coastguard Worker   size_t block_height_ = 1;
262*4bdc9457SAndroid Build Coastguard Worker   size_t block_width_ = 1;
263*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_ = 15;
264*4bdc9457SAndroid Build Coastguard Worker };
265