// xref: /aosp_15_r20/external/XNNPACK/test/transpose-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include <xnnpack.h>

#include <gtest/gtest.h>
19*4bdc9457SAndroid Build Coastguard Worker 
// Scalar reference for an N-dimensional transpose.
//
// Given `pos`, a flat row-major index into the transposed output, returns the
// flat row-major index of the input element that belongs at that position.
//   input_stride[d]  - row-major stride of input dimension d.
//   output_stride[d] - row-major stride of output dimension d.
//   perm[d]          - the input dimension that becomes output dimension d.
//   num_dims         - number of entries in each of the three arrays.
inline size_t reference_index(
    const size_t* input_stride,
    const size_t* output_stride,
    const size_t* perm,
    const size_t num_dims,
    size_t pos)
{
  size_t input_offset = 0;
  size_t remainder = pos;
  for (size_t dim = 0; dim != num_dims; dim++) {
    const size_t stride = output_stride[dim];
    // Peel off the coordinate along output dimension `dim`...
    const size_t coord = remainder / stride;
    remainder -= coord * stride;
    // ...and step the same distance along the matching input dimension.
    input_offset += coord * input_stride[perm[dim]];
  }
  return input_offset;
}
35*4bdc9457SAndroid Build Coastguard Worker 
36*4bdc9457SAndroid Build Coastguard Worker class TransposeOperatorTester {
37*4bdc9457SAndroid Build Coastguard Worker  public:
num_dims(size_t num_dims)38*4bdc9457SAndroid Build Coastguard Worker   inline TransposeOperatorTester& num_dims(size_t num_dims) {
39*4bdc9457SAndroid Build Coastguard Worker     assert(num_dims != 0);
40*4bdc9457SAndroid Build Coastguard Worker     this->num_dims_ = num_dims;
41*4bdc9457SAndroid Build Coastguard Worker     return *this;
42*4bdc9457SAndroid Build Coastguard Worker   }
43*4bdc9457SAndroid Build Coastguard Worker 
num_dims()44*4bdc9457SAndroid Build Coastguard Worker   inline size_t num_dims() const { return this->num_dims_; }
45*4bdc9457SAndroid Build Coastguard Worker 
shape(std::vector<size_t> shape)46*4bdc9457SAndroid Build Coastguard Worker   inline TransposeOperatorTester& shape(std::vector<size_t> shape) {
47*4bdc9457SAndroid Build Coastguard Worker     assert(shape.size() <= XNN_MAX_TENSOR_DIMS);
48*4bdc9457SAndroid Build Coastguard Worker     this->shape_ = shape;
49*4bdc9457SAndroid Build Coastguard Worker     return *this;
50*4bdc9457SAndroid Build Coastguard Worker   }
51*4bdc9457SAndroid Build Coastguard Worker 
dims()52*4bdc9457SAndroid Build Coastguard Worker   inline const std::vector<size_t>& dims() const { return this->shape_; }
53*4bdc9457SAndroid Build Coastguard Worker 
perm(std::vector<size_t> perm)54*4bdc9457SAndroid Build Coastguard Worker   inline TransposeOperatorTester& perm(std::vector<size_t> perm) {
55*4bdc9457SAndroid Build Coastguard Worker     assert(perm.size() <= XNN_MAX_TENSOR_DIMS);
56*4bdc9457SAndroid Build Coastguard Worker     this->perm_ = perm;
57*4bdc9457SAndroid Build Coastguard Worker     return *this;
58*4bdc9457SAndroid Build Coastguard Worker   }
59*4bdc9457SAndroid Build Coastguard Worker 
perm()60*4bdc9457SAndroid Build Coastguard Worker   inline const std::vector<size_t>& perm() const { return this->perm_; }
61*4bdc9457SAndroid Build Coastguard Worker 
TestX8()62*4bdc9457SAndroid Build Coastguard Worker   void TestX8() const {
63*4bdc9457SAndroid Build Coastguard Worker     size_t count = std::accumulate(dims().cbegin(), dims().cend(), 1, std::multiplies<size_t>());
64*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(count + XNN_EXTRA_BYTES / sizeof(uint8_t));
65*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(count);
66*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> input_stride(input.size(), 1);
67*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> output_stride(input.size(), 1);
68*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = num_dims() - 1; i > 0; --i) {
69*4bdc9457SAndroid Build Coastguard Worker       input_stride[i - 1] = input_stride[i] * shape_[i];
70*4bdc9457SAndroid Build Coastguard Worker       output_stride[i - 1] = output_stride[i] * shape_[perm()[i]];
71*4bdc9457SAndroid Build Coastguard Worker     }
72*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
73*4bdc9457SAndroid Build Coastguard Worker     xnn_operator_t transpose_op = nullptr;
74*4bdc9457SAndroid Build Coastguard Worker     std::iota(input.begin(), input.end(), 0);
75*4bdc9457SAndroid Build Coastguard Worker     std::fill(output.begin(), output.end(), UINT8_C(0xA5));
76*4bdc9457SAndroid Build Coastguard Worker 
77*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success,
78*4bdc9457SAndroid Build Coastguard Worker               xnn_create_transpose_nd_x8(0, &transpose_op));
79*4bdc9457SAndroid Build Coastguard Worker     ASSERT_NE(nullptr, transpose_op);
80*4bdc9457SAndroid Build Coastguard Worker 
81*4bdc9457SAndroid Build Coastguard Worker     // Smart pointer to automatically delete convert op.
82*4bdc9457SAndroid Build Coastguard Worker     std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_transpose_op(transpose_op, xnn_delete_operator);
83*4bdc9457SAndroid Build Coastguard Worker 
84*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success,
85*4bdc9457SAndroid Build Coastguard Worker               xnn_setup_transpose_nd_x8(
86*4bdc9457SAndroid Build Coastguard Worker                   transpose_op,
87*4bdc9457SAndroid Build Coastguard Worker                   input.data(), output.data(),
88*4bdc9457SAndroid Build Coastguard Worker                   num_dims(), shape_.data(), perm_.data(),
89*4bdc9457SAndroid Build Coastguard Worker                   nullptr /* thread pool */));
90*4bdc9457SAndroid Build Coastguard Worker 
91*4bdc9457SAndroid Build Coastguard Worker     // Run operator.
92*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success,
93*4bdc9457SAndroid Build Coastguard Worker               xnn_run_operator(transpose_op, nullptr /* thread pool */));
94*4bdc9457SAndroid Build Coastguard Worker 
95*4bdc9457SAndroid Build Coastguard Worker     // Verify results.
96*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < count; ++i) {
97*4bdc9457SAndroid Build Coastguard Worker       const size_t in_idx = reference_index(input_stride.data(), output_stride.data(), perm_.data(), num_dims(), i);
98*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(input[in_idx], output[i]);
99*4bdc9457SAndroid Build Coastguard Worker     }
100*4bdc9457SAndroid Build Coastguard Worker   }
101*4bdc9457SAndroid Build Coastguard Worker 
TestRunX8()102*4bdc9457SAndroid Build Coastguard Worker     void TestRunX8() const {
103*4bdc9457SAndroid Build Coastguard Worker     const size_t count = std::accumulate(dims().cbegin(), dims().cend(), 1, std::multiplies<size_t>());
104*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(count + XNN_EXTRA_BYTES / sizeof(uint8_t));
105*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(count);
106*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> input_stride(input.size(), 1);
107*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> output_stride(input.size(), 1);
108*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = num_dims() - 1; i > 0; --i) {
109*4bdc9457SAndroid Build Coastguard Worker       input_stride[i - 1] = input_stride[i] * shape_[i];
110*4bdc9457SAndroid Build Coastguard Worker       output_stride[i - 1] = output_stride[i] * shape_[perm()[i]];
111*4bdc9457SAndroid Build Coastguard Worker     }
112*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
113*4bdc9457SAndroid Build Coastguard Worker     std::iota(input.begin(), input.end(), 0);
114*4bdc9457SAndroid Build Coastguard Worker     std::fill(output.begin(), output.end(), UINT8_C(0xA5));
115*4bdc9457SAndroid Build Coastguard Worker 
116*4bdc9457SAndroid Build Coastguard Worker     // Call transpose eager API
117*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success,
118*4bdc9457SAndroid Build Coastguard Worker               xnn_run_transpose_nd_x8(
119*4bdc9457SAndroid Build Coastguard Worker                   0 /* flags */,
120*4bdc9457SAndroid Build Coastguard Worker                   input.data(), output.data(),
121*4bdc9457SAndroid Build Coastguard Worker                   num_dims(), shape_.data(), perm_.data(),
122*4bdc9457SAndroid Build Coastguard Worker                   nullptr /* thread pool */));
123*4bdc9457SAndroid Build Coastguard Worker 
124*4bdc9457SAndroid Build Coastguard Worker     // Verify results.
125*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < count; ++i) {
126*4bdc9457SAndroid Build Coastguard Worker       const size_t in_idx = reference_index(input_stride.data(), output_stride.data(), perm_.data(), num_dims(), i);
127*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(input[in_idx], output[i]);
128*4bdc9457SAndroid Build Coastguard Worker     }
129*4bdc9457SAndroid Build Coastguard Worker   }
130*4bdc9457SAndroid Build Coastguard Worker 
TestX16()131*4bdc9457SAndroid Build Coastguard Worker   void TestX16() const {
132*4bdc9457SAndroid Build Coastguard Worker     size_t count = std::accumulate(dims().cbegin(), dims().cend(), 1, std::multiplies<size_t>());
133*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(count + XNN_EXTRA_BYTES / sizeof(uint16_t));
134*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output(count);
135*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> input_stride(input.size(), 1);
136*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> output_stride(input.size(), 1);
137*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = num_dims() - 1; i > 0; --i) {
138*4bdc9457SAndroid Build Coastguard Worker       input_stride[i - 1] = input_stride[i] * shape_[i];
139*4bdc9457SAndroid Build Coastguard Worker       output_stride[i - 1] = output_stride[i] * shape_[perm()[i]];
140*4bdc9457SAndroid Build Coastguard Worker     }
141*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
142*4bdc9457SAndroid Build Coastguard Worker     xnn_operator_t transpose_op = nullptr;
143*4bdc9457SAndroid Build Coastguard Worker     std::iota(input.begin(), input.end(), 0);
144*4bdc9457SAndroid Build Coastguard Worker     std::fill(output.begin(), output.end(), UINT16_C(0xDEAD));
145*4bdc9457SAndroid Build Coastguard Worker 
146*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success,
147*4bdc9457SAndroid Build Coastguard Worker               xnn_create_transpose_nd_x16(0, &transpose_op));
148*4bdc9457SAndroid Build Coastguard Worker     ASSERT_NE(nullptr, transpose_op);
149*4bdc9457SAndroid Build Coastguard Worker 
150*4bdc9457SAndroid Build Coastguard Worker     // Smart pointer to automatically delete convert op.
151*4bdc9457SAndroid Build Coastguard Worker     std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_transpose_op(transpose_op, xnn_delete_operator);
152*4bdc9457SAndroid Build Coastguard Worker 
153*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success,
154*4bdc9457SAndroid Build Coastguard Worker               xnn_setup_transpose_nd_x16(
155*4bdc9457SAndroid Build Coastguard Worker                   transpose_op,
156*4bdc9457SAndroid Build Coastguard Worker                   input.data(), output.data(),
157*4bdc9457SAndroid Build Coastguard Worker                   num_dims(), shape_.data(), perm_.data(),
158*4bdc9457SAndroid Build Coastguard Worker                   nullptr /* thread pool */));
159*4bdc9457SAndroid Build Coastguard Worker 
160*4bdc9457SAndroid Build Coastguard Worker     // Run operator.
161*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success,
162*4bdc9457SAndroid Build Coastguard Worker               xnn_run_operator(transpose_op, nullptr /* thread pool */));
163*4bdc9457SAndroid Build Coastguard Worker 
164*4bdc9457SAndroid Build Coastguard Worker     // Verify results.
165*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < count; ++i) {
166*4bdc9457SAndroid Build Coastguard Worker       const size_t in_idx = reference_index(input_stride.data(), output_stride.data(), perm_.data(), num_dims(), i);
167*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(input[in_idx], output[i]);
168*4bdc9457SAndroid Build Coastguard Worker     }
169*4bdc9457SAndroid Build Coastguard Worker   }
170*4bdc9457SAndroid Build Coastguard Worker 
TestRunX16()171*4bdc9457SAndroid Build Coastguard Worker   void TestRunX16() const {
172*4bdc9457SAndroid Build Coastguard Worker     const size_t count = std::accumulate(dims().cbegin(), dims().cend(), 1, std::multiplies<size_t>());
173*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(count + XNN_EXTRA_BYTES / sizeof(uint16_t));
174*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output(count);
175*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> input_stride(input.size(), 1);
176*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> output_stride(input.size(), 1);
177*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = num_dims() - 1; i > 0; --i) {
178*4bdc9457SAndroid Build Coastguard Worker       input_stride[i - 1] = input_stride[i] * shape_[i];
179*4bdc9457SAndroid Build Coastguard Worker       output_stride[i - 1] = output_stride[i] * shape_[perm()[i]];
180*4bdc9457SAndroid Build Coastguard Worker     }
181*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
182*4bdc9457SAndroid Build Coastguard Worker     std::iota(input.begin(), input.end(), 0);
183*4bdc9457SAndroid Build Coastguard Worker     std::fill(output.begin(), output.end(), UINT16_C(0xDEADBEEF));
184*4bdc9457SAndroid Build Coastguard Worker 
185*4bdc9457SAndroid Build Coastguard Worker     // Call transpose eager API
186*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success,
187*4bdc9457SAndroid Build Coastguard Worker               xnn_run_transpose_nd_x16(
188*4bdc9457SAndroid Build Coastguard Worker                   0 /* flags */,
189*4bdc9457SAndroid Build Coastguard Worker                   input.data(), output.data(),
190*4bdc9457SAndroid Build Coastguard Worker                   num_dims(), shape_.data(), perm_.data(),
191*4bdc9457SAndroid Build Coastguard Worker                   nullptr /* thread pool */));
192*4bdc9457SAndroid Build Coastguard Worker 
193*4bdc9457SAndroid Build Coastguard Worker     // Verify results.
194*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < count; ++i) {
195*4bdc9457SAndroid Build Coastguard Worker       const size_t in_idx = reference_index(input_stride.data(), output_stride.data(), perm_.data(), num_dims(), i);
196*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(input[in_idx], output[i]);
197*4bdc9457SAndroid Build Coastguard Worker     }
198*4bdc9457SAndroid Build Coastguard Worker   }
199*4bdc9457SAndroid Build Coastguard Worker 
TestX32()200*4bdc9457SAndroid Build Coastguard Worker   void TestX32() const {
201*4bdc9457SAndroid Build Coastguard Worker     size_t count = std::accumulate(dims().cbegin(), dims().cend(), 1, std::multiplies<size_t>());
202*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> input(count + XNN_EXTRA_BYTES / sizeof(uint32_t));
203*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> output(count);
204*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> input_stride(input.size(), 1);
205*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> output_stride(input.size(), 1);
206*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = num_dims() - 1; i > 0; --i) {
207*4bdc9457SAndroid Build Coastguard Worker       input_stride[i - 1] = input_stride[i] * shape_[i];
208*4bdc9457SAndroid Build Coastguard Worker       output_stride[i - 1] = output_stride[i] * shape_[perm()[i]];
209*4bdc9457SAndroid Build Coastguard Worker     }
210*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
211*4bdc9457SAndroid Build Coastguard Worker     xnn_operator_t transpose_op = nullptr;
212*4bdc9457SAndroid Build Coastguard Worker     std::iota(input.begin(), input.end(), 0);
213*4bdc9457SAndroid Build Coastguard Worker     std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEEF));
214*4bdc9457SAndroid Build Coastguard Worker 
215*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success,
216*4bdc9457SAndroid Build Coastguard Worker               xnn_create_transpose_nd_x32(0, &transpose_op));
217*4bdc9457SAndroid Build Coastguard Worker     ASSERT_NE(nullptr, transpose_op);
218*4bdc9457SAndroid Build Coastguard Worker 
219*4bdc9457SAndroid Build Coastguard Worker     // Smart pointer to automatically delete convert op.
220*4bdc9457SAndroid Build Coastguard Worker     std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_transpose_op(transpose_op, xnn_delete_operator);
221*4bdc9457SAndroid Build Coastguard Worker 
222*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success,
223*4bdc9457SAndroid Build Coastguard Worker               xnn_setup_transpose_nd_x32(
224*4bdc9457SAndroid Build Coastguard Worker                   transpose_op,
225*4bdc9457SAndroid Build Coastguard Worker                   input.data(), output.data(),
226*4bdc9457SAndroid Build Coastguard Worker                   num_dims(), shape_.data(), perm_.data(),
227*4bdc9457SAndroid Build Coastguard Worker                   nullptr /* thread pool */));
228*4bdc9457SAndroid Build Coastguard Worker 
229*4bdc9457SAndroid Build Coastguard Worker     // Run operator.
230*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success,
231*4bdc9457SAndroid Build Coastguard Worker               xnn_run_operator(transpose_op, nullptr /* thread pool */));
232*4bdc9457SAndroid Build Coastguard Worker 
233*4bdc9457SAndroid Build Coastguard Worker     // Verify results.
234*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < count; ++i) {
235*4bdc9457SAndroid Build Coastguard Worker       const size_t in_idx = reference_index(input_stride.data(), output_stride.data(), perm_.data(), num_dims(), i);
236*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(input[in_idx], output[i]);
237*4bdc9457SAndroid Build Coastguard Worker     }
238*4bdc9457SAndroid Build Coastguard Worker   }
239*4bdc9457SAndroid Build Coastguard Worker 
TestRunX32()240*4bdc9457SAndroid Build Coastguard Worker   void TestRunX32() const {
241*4bdc9457SAndroid Build Coastguard Worker     const size_t count = std::accumulate(dims().cbegin(), dims().cend(), 1, std::multiplies<size_t>());
242*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> input(count + XNN_EXTRA_BYTES / sizeof(uint32_t));
243*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> output(count);
244*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> input_stride(input.size(), 1);
245*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> output_stride(input.size(), 1);
246*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = num_dims() - 1; i > 0; --i) {
247*4bdc9457SAndroid Build Coastguard Worker       input_stride[i - 1] = input_stride[i] * shape_[i];
248*4bdc9457SAndroid Build Coastguard Worker       output_stride[i - 1] = output_stride[i] * shape_[perm()[i]];
249*4bdc9457SAndroid Build Coastguard Worker     }
250*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
251*4bdc9457SAndroid Build Coastguard Worker     std::iota(input.begin(), input.end(), 0);
252*4bdc9457SAndroid Build Coastguard Worker     std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEEF));
253*4bdc9457SAndroid Build Coastguard Worker 
254*4bdc9457SAndroid Build Coastguard Worker     // Call transpose eager API
255*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success,
256*4bdc9457SAndroid Build Coastguard Worker               xnn_run_transpose_nd_x32(
257*4bdc9457SAndroid Build Coastguard Worker                   0,
258*4bdc9457SAndroid Build Coastguard Worker                   input.data(), output.data(),
259*4bdc9457SAndroid Build Coastguard Worker                   num_dims(), shape_.data(), perm_.data(),
260*4bdc9457SAndroid Build Coastguard Worker                   nullptr /* thread pool */));
261*4bdc9457SAndroid Build Coastguard Worker 
262*4bdc9457SAndroid Build Coastguard Worker     // Verify results.
263*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < count; ++i) {
264*4bdc9457SAndroid Build Coastguard Worker       const size_t in_idx = reference_index(input_stride.data(), output_stride.data(), perm_.data(), num_dims(), i);
265*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(input[in_idx], output[i]);
266*4bdc9457SAndroid Build Coastguard Worker     }
267*4bdc9457SAndroid Build Coastguard Worker   }
268*4bdc9457SAndroid Build Coastguard Worker 
269*4bdc9457SAndroid Build Coastguard Worker  private:
270*4bdc9457SAndroid Build Coastguard Worker   size_t num_dims_ = 1;
271*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> shape_;
272*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> perm_;
273*4bdc9457SAndroid Build Coastguard Worker };
274