xref: /aosp_15_r20/external/XNNPACK/test/space-to-depth-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

#include <xnnpack.h>
20*4bdc9457SAndroid Build Coastguard Worker 
21*4bdc9457SAndroid Build Coastguard Worker 
22*4bdc9457SAndroid Build Coastguard Worker class SpaceToDepthOperatorTester {
23*4bdc9457SAndroid Build Coastguard Worker  public:
input_size(size_t input_height,size_t input_width)24*4bdc9457SAndroid Build Coastguard Worker   inline SpaceToDepthOperatorTester& input_size(size_t input_height, size_t input_width) {
25*4bdc9457SAndroid Build Coastguard Worker     assert(input_height >= 1);
26*4bdc9457SAndroid Build Coastguard Worker     assert(input_width >= 1);
27*4bdc9457SAndroid Build Coastguard Worker     this->input_height_ = input_height;
28*4bdc9457SAndroid Build Coastguard Worker     this->input_width_ = input_width;
29*4bdc9457SAndroid Build Coastguard Worker     return *this;
30*4bdc9457SAndroid Build Coastguard Worker   }
31*4bdc9457SAndroid Build Coastguard Worker 
input_height(size_t input_height)32*4bdc9457SAndroid Build Coastguard Worker   inline SpaceToDepthOperatorTester& input_height(size_t input_height) {
33*4bdc9457SAndroid Build Coastguard Worker     assert(input_height >= 1);
34*4bdc9457SAndroid Build Coastguard Worker     this->input_height_ = input_height;
35*4bdc9457SAndroid Build Coastguard Worker     return *this;
36*4bdc9457SAndroid Build Coastguard Worker   }
37*4bdc9457SAndroid Build Coastguard Worker 
input_height()38*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_height() const {
39*4bdc9457SAndroid Build Coastguard Worker     return this->input_height_;
40*4bdc9457SAndroid Build Coastguard Worker   }
41*4bdc9457SAndroid Build Coastguard Worker 
input_width(size_t input_width)42*4bdc9457SAndroid Build Coastguard Worker   inline SpaceToDepthOperatorTester& input_width(size_t input_width) {
43*4bdc9457SAndroid Build Coastguard Worker     assert(input_width >= 1);
44*4bdc9457SAndroid Build Coastguard Worker     this->input_width_ = input_width;
45*4bdc9457SAndroid Build Coastguard Worker     return *this;
46*4bdc9457SAndroid Build Coastguard Worker   }
47*4bdc9457SAndroid Build Coastguard Worker 
input_width()48*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_width() const {
49*4bdc9457SAndroid Build Coastguard Worker     return this->input_width_;
50*4bdc9457SAndroid Build Coastguard Worker   }
51*4bdc9457SAndroid Build Coastguard Worker 
output_height()52*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_height() const {
53*4bdc9457SAndroid Build Coastguard Worker     assert(input_height() % block_size() == 0);
54*4bdc9457SAndroid Build Coastguard Worker     return input_height() / block_size();
55*4bdc9457SAndroid Build Coastguard Worker   }
56*4bdc9457SAndroid Build Coastguard Worker 
output_width()57*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_width() const {
58*4bdc9457SAndroid Build Coastguard Worker     assert(input_width() % block_size() == 0);
59*4bdc9457SAndroid Build Coastguard Worker     return input_width() / block_size();
60*4bdc9457SAndroid Build Coastguard Worker   }
61*4bdc9457SAndroid Build Coastguard Worker 
block_size(size_t block_size)62*4bdc9457SAndroid Build Coastguard Worker   inline SpaceToDepthOperatorTester& block_size(size_t block_size) {
63*4bdc9457SAndroid Build Coastguard Worker     assert(block_size >= 2);
64*4bdc9457SAndroid Build Coastguard Worker     this->block_size_ = block_size;
65*4bdc9457SAndroid Build Coastguard Worker     return *this;
66*4bdc9457SAndroid Build Coastguard Worker   }
67*4bdc9457SAndroid Build Coastguard Worker 
block_size()68*4bdc9457SAndroid Build Coastguard Worker   inline size_t block_size() const {
69*4bdc9457SAndroid Build Coastguard Worker     return this->block_size_;
70*4bdc9457SAndroid Build Coastguard Worker   }
71*4bdc9457SAndroid Build Coastguard Worker 
input_channels(size_t input_channels)72*4bdc9457SAndroid Build Coastguard Worker   inline SpaceToDepthOperatorTester& input_channels(size_t input_channels) {
73*4bdc9457SAndroid Build Coastguard Worker     assert(input_channels != 0);
74*4bdc9457SAndroid Build Coastguard Worker     this->input_channels_ = input_channels;
75*4bdc9457SAndroid Build Coastguard Worker     return *this;
76*4bdc9457SAndroid Build Coastguard Worker   }
77*4bdc9457SAndroid Build Coastguard Worker 
input_channels()78*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_channels() const {
79*4bdc9457SAndroid Build Coastguard Worker     return this->input_channels_;
80*4bdc9457SAndroid Build Coastguard Worker   }
81*4bdc9457SAndroid Build Coastguard Worker 
output_channels()82*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_channels() const {
83*4bdc9457SAndroid Build Coastguard Worker     return input_channels() * block_size() * block_size();
84*4bdc9457SAndroid Build Coastguard Worker   }
85*4bdc9457SAndroid Build Coastguard Worker 
batch_size(size_t batch_size)86*4bdc9457SAndroid Build Coastguard Worker   inline SpaceToDepthOperatorTester& batch_size(size_t batch_size) {
87*4bdc9457SAndroid Build Coastguard Worker     assert(batch_size != 0);
88*4bdc9457SAndroid Build Coastguard Worker     this->batch_size_ = batch_size;
89*4bdc9457SAndroid Build Coastguard Worker     return *this;
90*4bdc9457SAndroid Build Coastguard Worker   }
91*4bdc9457SAndroid Build Coastguard Worker 
batch_size()92*4bdc9457SAndroid Build Coastguard Worker   inline size_t batch_size() const {
93*4bdc9457SAndroid Build Coastguard Worker     return this->batch_size_;
94*4bdc9457SAndroid Build Coastguard Worker   }
95*4bdc9457SAndroid Build Coastguard Worker 
input_channels_stride(size_t input_channels_stride)96*4bdc9457SAndroid Build Coastguard Worker   inline SpaceToDepthOperatorTester& input_channels_stride(size_t input_channels_stride) {
97*4bdc9457SAndroid Build Coastguard Worker     assert(input_channels_stride >= 1);
98*4bdc9457SAndroid Build Coastguard Worker     this->input_channels_stride_ = input_channels_stride;
99*4bdc9457SAndroid Build Coastguard Worker     return *this;
100*4bdc9457SAndroid Build Coastguard Worker   }
101*4bdc9457SAndroid Build Coastguard Worker 
input_channels_stride()102*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_channels_stride() const {
103*4bdc9457SAndroid Build Coastguard Worker     if (this->input_channels_stride_ == 0) {
104*4bdc9457SAndroid Build Coastguard Worker       return input_channels();
105*4bdc9457SAndroid Build Coastguard Worker     } else {
106*4bdc9457SAndroid Build Coastguard Worker       assert(this->input_channels_stride_ >= input_channels());
107*4bdc9457SAndroid Build Coastguard Worker       return this->input_channels_stride_;
108*4bdc9457SAndroid Build Coastguard Worker     }
109*4bdc9457SAndroid Build Coastguard Worker   }
110*4bdc9457SAndroid Build Coastguard Worker 
output_channels_stride(size_t output_channels_stride)111*4bdc9457SAndroid Build Coastguard Worker   inline SpaceToDepthOperatorTester& output_channels_stride(size_t output_channels_stride) {
112*4bdc9457SAndroid Build Coastguard Worker     assert(output_channels_stride >= 1);
113*4bdc9457SAndroid Build Coastguard Worker     this->output_channels_stride_ = output_channels_stride;
114*4bdc9457SAndroid Build Coastguard Worker     return *this;
115*4bdc9457SAndroid Build Coastguard Worker   }
116*4bdc9457SAndroid Build Coastguard Worker 
output_channels_stride()117*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_channels_stride() const {
118*4bdc9457SAndroid Build Coastguard Worker     if (this->output_channels_stride_ == 0) {
119*4bdc9457SAndroid Build Coastguard Worker       return output_channels();
120*4bdc9457SAndroid Build Coastguard Worker     } else {
121*4bdc9457SAndroid Build Coastguard Worker       assert(this->output_channels_stride_ >= output_channels());
122*4bdc9457SAndroid Build Coastguard Worker       return this->output_channels_stride_;
123*4bdc9457SAndroid Build Coastguard Worker     }
124*4bdc9457SAndroid Build Coastguard Worker   }
125*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)126*4bdc9457SAndroid Build Coastguard Worker   inline SpaceToDepthOperatorTester& iterations(size_t iterations) {
127*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
128*4bdc9457SAndroid Build Coastguard Worker     return *this;
129*4bdc9457SAndroid Build Coastguard Worker   }
130*4bdc9457SAndroid Build Coastguard Worker 
iterations()131*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
132*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
133*4bdc9457SAndroid Build Coastguard Worker   }
134*4bdc9457SAndroid Build Coastguard Worker 
TestNHWCxX8()135*4bdc9457SAndroid Build Coastguard Worker   void TestNHWCxX8() const {
136*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input(
137*4bdc9457SAndroid Build Coastguard Worker       (batch_size() * input_height() * input_width() - 1) * input_channels_stride() + input_channels());
138*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output(
139*4bdc9457SAndroid Build Coastguard Worker       (batch_size() * output_height() * output_width() - 1) * output_channels_stride() + output_channels());
140*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
141*4bdc9457SAndroid Build Coastguard Worker       std::iota(input.begin(), input.end(), 0);
142*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xAF));
143*4bdc9457SAndroid Build Coastguard Worker 
144*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Depth To Space operator.
145*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
146*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t space_to_depth_op = nullptr;
147*4bdc9457SAndroid Build Coastguard Worker 
148*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
149*4bdc9457SAndroid Build Coastguard Worker                 xnn_create_space_to_depth_nhwc_x8(
150*4bdc9457SAndroid Build Coastguard Worker                     input_channels(), input_channels_stride(), output_channels_stride(),
151*4bdc9457SAndroid Build Coastguard Worker                     block_size(), 0, &space_to_depth_op));
152*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, space_to_depth_op);
153*4bdc9457SAndroid Build Coastguard Worker 
154*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete space_to_depth_op.
155*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_space_to_depth_op(space_to_depth_op, xnn_delete_operator);
156*4bdc9457SAndroid Build Coastguard Worker 
157*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
158*4bdc9457SAndroid Build Coastguard Worker                 xnn_setup_space_to_depth_nhwc_x8(
159*4bdc9457SAndroid Build Coastguard Worker                     space_to_depth_op,
160*4bdc9457SAndroid Build Coastguard Worker                     batch_size(), input_height(), input_width(),
161*4bdc9457SAndroid Build Coastguard Worker                     input.data(), output.data(), nullptr /* thread pool */));
162*4bdc9457SAndroid Build Coastguard Worker 
163*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
164*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(space_to_depth_op, nullptr /* thread pool */));
165*4bdc9457SAndroid Build Coastguard Worker 
166*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
167*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
168*4bdc9457SAndroid Build Coastguard Worker         for (size_t iy = 0; iy < output_height(); iy++) {
169*4bdc9457SAndroid Build Coastguard Worker           for (size_t ix = 0; ix < output_width(); ix++) {
170*4bdc9457SAndroid Build Coastguard Worker             for (size_t by = 0; by < block_size(); by++) {
171*4bdc9457SAndroid Build Coastguard Worker               for (size_t bx = 0; bx < block_size(); bx++) {
172*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < input_channels(); oc++) {
173*4bdc9457SAndroid Build Coastguard Worker                   const size_t input_index = oc
174*4bdc9457SAndroid Build Coastguard Worker                       + bx * input_channels_stride()
175*4bdc9457SAndroid Build Coastguard Worker                       + ix * block_size() * input_channels_stride()
176*4bdc9457SAndroid Build Coastguard Worker                       + by * output_width() * block_size() * input_channels_stride()
177*4bdc9457SAndroid Build Coastguard Worker                       + iy * block_size() * output_width() * block_size() * input_channels_stride()
178*4bdc9457SAndroid Build Coastguard Worker                       + i * output_height() * block_size() * output_width() * block_size() * input_channels_stride();
179*4bdc9457SAndroid Build Coastguard Worker                   const size_t output_index = oc
180*4bdc9457SAndroid Build Coastguard Worker                       + bx * input_channels()
181*4bdc9457SAndroid Build Coastguard Worker                       + by * input_channels() * block_size()
182*4bdc9457SAndroid Build Coastguard Worker                       + ix * output_channels_stride()
183*4bdc9457SAndroid Build Coastguard Worker                       + iy * output_width() * output_channels_stride()
184*4bdc9457SAndroid Build Coastguard Worker                       + i * output_height() * output_width() * output_channels_stride();
185*4bdc9457SAndroid Build Coastguard Worker 
186*4bdc9457SAndroid Build Coastguard Worker                   ASSERT_EQ(int32_t(output[output_index]), int32_t(input[input_index]))
187*4bdc9457SAndroid Build Coastguard Worker                     << "batch: " << i << " / " << batch_size()
188*4bdc9457SAndroid Build Coastguard Worker                     << ", output x: " << ix << " / " << output_width()
189*4bdc9457SAndroid Build Coastguard Worker                     << ", output y: " << iy << " / " << output_height()
190*4bdc9457SAndroid Build Coastguard Worker                     << ", block x: " << bx << " / " << block_size()
191*4bdc9457SAndroid Build Coastguard Worker                     << ", block y: " << by << " / " << block_size()
192*4bdc9457SAndroid Build Coastguard Worker                     << ", input channel: " << oc << " / " << input_channels()
193*4bdc9457SAndroid Build Coastguard Worker                     << ", input stride: " << input_channels_stride()
194*4bdc9457SAndroid Build Coastguard Worker                     << ", output stride: " << output_channels_stride();
195*4bdc9457SAndroid Build Coastguard Worker                 }
196*4bdc9457SAndroid Build Coastguard Worker               }
197*4bdc9457SAndroid Build Coastguard Worker             }
198*4bdc9457SAndroid Build Coastguard Worker           }
199*4bdc9457SAndroid Build Coastguard Worker         }
200*4bdc9457SAndroid Build Coastguard Worker       }
201*4bdc9457SAndroid Build Coastguard Worker     }
202*4bdc9457SAndroid Build Coastguard Worker   }
203*4bdc9457SAndroid Build Coastguard Worker 
TestNHWCxX16()204*4bdc9457SAndroid Build Coastguard Worker   void TestNHWCxX16() const {
205*4bdc9457SAndroid Build Coastguard Worker     std::vector<int16_t> input(
206*4bdc9457SAndroid Build Coastguard Worker       (batch_size() * input_height() * input_width() - 1) * input_channels_stride() + input_channels());
207*4bdc9457SAndroid Build Coastguard Worker     std::vector<int16_t> output(
208*4bdc9457SAndroid Build Coastguard Worker       (batch_size() * output_height() * output_width() - 1) * output_channels_stride() + output_channels());
209*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
210*4bdc9457SAndroid Build Coastguard Worker       std::iota(input.begin(), input.end(), 0);
211*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT16_C(0xDEAD));
212*4bdc9457SAndroid Build Coastguard Worker 
213*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Depth To Space operator.
214*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
215*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t space_to_depth_op = nullptr;
216*4bdc9457SAndroid Build Coastguard Worker 
217*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
218*4bdc9457SAndroid Build Coastguard Worker                 xnn_create_space_to_depth_nhwc_x16(
219*4bdc9457SAndroid Build Coastguard Worker                     input_channels(), input_channels_stride(), output_channels_stride(),
220*4bdc9457SAndroid Build Coastguard Worker                     block_size(), 0, &space_to_depth_op));
221*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, space_to_depth_op);
222*4bdc9457SAndroid Build Coastguard Worker 
223*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete space_to_depth_op.
224*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_space_to_depth_op(space_to_depth_op, xnn_delete_operator);
225*4bdc9457SAndroid Build Coastguard Worker 
226*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
227*4bdc9457SAndroid Build Coastguard Worker                 xnn_setup_space_to_depth_nhwc_x16(
228*4bdc9457SAndroid Build Coastguard Worker                     space_to_depth_op,
229*4bdc9457SAndroid Build Coastguard Worker                     batch_size(), input_height(), input_width(),
230*4bdc9457SAndroid Build Coastguard Worker                     input.data(), output.data(), nullptr /* thread pool */));
231*4bdc9457SAndroid Build Coastguard Worker 
232*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
233*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(space_to_depth_op, nullptr /* thread pool */));
234*4bdc9457SAndroid Build Coastguard Worker 
235*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
236*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
237*4bdc9457SAndroid Build Coastguard Worker         for (size_t iy = 0; iy < output_height(); iy++) {
238*4bdc9457SAndroid Build Coastguard Worker           for (size_t ix = 0; ix < output_width(); ix++) {
239*4bdc9457SAndroid Build Coastguard Worker             for (size_t by = 0; by < block_size(); by++) {
240*4bdc9457SAndroid Build Coastguard Worker               for (size_t bx = 0; bx < block_size(); bx++) {
241*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < input_channels(); oc++) {
242*4bdc9457SAndroid Build Coastguard Worker                   const size_t input_index = oc
243*4bdc9457SAndroid Build Coastguard Worker                       + bx * input_channels_stride()
244*4bdc9457SAndroid Build Coastguard Worker                       + ix * block_size() * input_channels_stride()
245*4bdc9457SAndroid Build Coastguard Worker                       + by * output_width() * block_size() * input_channels_stride()
246*4bdc9457SAndroid Build Coastguard Worker                       + iy * block_size() * output_width() * block_size() * input_channels_stride()
247*4bdc9457SAndroid Build Coastguard Worker                       + i * output_height() * block_size() * output_width() * block_size() * input_channels_stride();
248*4bdc9457SAndroid Build Coastguard Worker                   const size_t output_index = oc
249*4bdc9457SAndroid Build Coastguard Worker                       + bx * input_channels()
250*4bdc9457SAndroid Build Coastguard Worker                       + by * input_channels() * block_size()
251*4bdc9457SAndroid Build Coastguard Worker                       + ix * output_channels_stride()
252*4bdc9457SAndroid Build Coastguard Worker                       + iy * output_width() * output_channels_stride()
253*4bdc9457SAndroid Build Coastguard Worker                       + i * output_height() * output_width() * output_channels_stride();
254*4bdc9457SAndroid Build Coastguard Worker 
255*4bdc9457SAndroid Build Coastguard Worker                   ASSERT_EQ(int32_t(output[output_index]), int32_t(input[input_index]))
256*4bdc9457SAndroid Build Coastguard Worker                     << "batch: " << i << " / " << batch_size()
257*4bdc9457SAndroid Build Coastguard Worker                     << ", output x: " << ix << " / " << output_width()
258*4bdc9457SAndroid Build Coastguard Worker                     << ", output y: " << iy << " / " << output_height()
259*4bdc9457SAndroid Build Coastguard Worker                     << ", block x: " << bx << " / " << block_size()
260*4bdc9457SAndroid Build Coastguard Worker                     << ", block y: " << by << " / " << block_size()
261*4bdc9457SAndroid Build Coastguard Worker                     << ", input channel: " << oc << " / " << input_channels()
262*4bdc9457SAndroid Build Coastguard Worker                     << ", input stride: " << input_channels_stride()
263*4bdc9457SAndroid Build Coastguard Worker                     << ", output stride: " << output_channels_stride();
264*4bdc9457SAndroid Build Coastguard Worker                 }
265*4bdc9457SAndroid Build Coastguard Worker               }
266*4bdc9457SAndroid Build Coastguard Worker             }
267*4bdc9457SAndroid Build Coastguard Worker           }
268*4bdc9457SAndroid Build Coastguard Worker         }
269*4bdc9457SAndroid Build Coastguard Worker       }
270*4bdc9457SAndroid Build Coastguard Worker     }
271*4bdc9457SAndroid Build Coastguard Worker   }
272*4bdc9457SAndroid Build Coastguard Worker 
TestNHWCxX32()273*4bdc9457SAndroid Build Coastguard Worker   void TestNHWCxX32() const {
274*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> input(
275*4bdc9457SAndroid Build Coastguard Worker       (batch_size() * input_height() * input_width() - 1) * input_channels_stride() + input_channels());
276*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> output(
277*4bdc9457SAndroid Build Coastguard Worker       (batch_size() * output_height() * output_width() - 1) * output_channels_stride() + output_channels());
278*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
279*4bdc9457SAndroid Build Coastguard Worker       std::iota(input.begin(), input.end(), 0);
280*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT32_C(0xDEADBEEF));
281*4bdc9457SAndroid Build Coastguard Worker 
282*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Depth To Space operator.
283*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
284*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t space_to_depth_op = nullptr;
285*4bdc9457SAndroid Build Coastguard Worker 
286*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
287*4bdc9457SAndroid Build Coastguard Worker                 xnn_create_space_to_depth_nhwc_x32(
288*4bdc9457SAndroid Build Coastguard Worker                     input_channels(), input_channels_stride(), output_channels_stride(),
289*4bdc9457SAndroid Build Coastguard Worker                     block_size(), 0, &space_to_depth_op));
290*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, space_to_depth_op);
291*4bdc9457SAndroid Build Coastguard Worker 
292*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete space_to_depth_op.
293*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_space_to_depth_op(space_to_depth_op, xnn_delete_operator);
294*4bdc9457SAndroid Build Coastguard Worker 
295*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
296*4bdc9457SAndroid Build Coastguard Worker                 xnn_setup_space_to_depth_nhwc_x32(
297*4bdc9457SAndroid Build Coastguard Worker                     space_to_depth_op,
298*4bdc9457SAndroid Build Coastguard Worker                     batch_size(), input_height(), input_width(),
299*4bdc9457SAndroid Build Coastguard Worker                     input.data(), output.data(), nullptr /* thread pool */));
300*4bdc9457SAndroid Build Coastguard Worker 
301*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
302*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(space_to_depth_op, nullptr /* thread pool */));
303*4bdc9457SAndroid Build Coastguard Worker 
304*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
305*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
306*4bdc9457SAndroid Build Coastguard Worker         for (size_t iy = 0; iy < output_height(); iy++) {
307*4bdc9457SAndroid Build Coastguard Worker           for (size_t ix = 0; ix < output_width(); ix++) {
308*4bdc9457SAndroid Build Coastguard Worker             for (size_t by = 0; by < block_size(); by++) {
309*4bdc9457SAndroid Build Coastguard Worker               for (size_t bx = 0; bx < block_size(); bx++) {
310*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < input_channels(); oc++) {
311*4bdc9457SAndroid Build Coastguard Worker                   const size_t input_index = oc
312*4bdc9457SAndroid Build Coastguard Worker                       + bx * input_channels_stride()
313*4bdc9457SAndroid Build Coastguard Worker                       + ix * block_size() * input_channels_stride()
314*4bdc9457SAndroid Build Coastguard Worker                       + by * output_width() * block_size() * input_channels_stride()
315*4bdc9457SAndroid Build Coastguard Worker                       + iy * block_size() * output_width() * block_size() * input_channels_stride()
316*4bdc9457SAndroid Build Coastguard Worker                       + i * output_height() * block_size() * output_width() * block_size() * input_channels_stride();
317*4bdc9457SAndroid Build Coastguard Worker                   const size_t output_index = oc
318*4bdc9457SAndroid Build Coastguard Worker                       + bx * input_channels()
319*4bdc9457SAndroid Build Coastguard Worker                       + by * input_channels() * block_size()
320*4bdc9457SAndroid Build Coastguard Worker                       + ix * output_channels_stride()
321*4bdc9457SAndroid Build Coastguard Worker                       + iy * output_width() * output_channels_stride()
322*4bdc9457SAndroid Build Coastguard Worker                       + i * output_height() * output_width() * output_channels_stride();
323*4bdc9457SAndroid Build Coastguard Worker 
324*4bdc9457SAndroid Build Coastguard Worker                   ASSERT_EQ(int32_t(output[output_index]), int32_t(input[input_index]))
325*4bdc9457SAndroid Build Coastguard Worker                     << "batch: " << i << " / " << batch_size()
326*4bdc9457SAndroid Build Coastguard Worker                     << ", output x: " << ix << " / " << output_width()
327*4bdc9457SAndroid Build Coastguard Worker                     << ", output y: " << iy << " / " << output_height()
328*4bdc9457SAndroid Build Coastguard Worker                     << ", block x: " << bx << " / " << block_size()
329*4bdc9457SAndroid Build Coastguard Worker                     << ", block y: " << by << " / " << block_size()
330*4bdc9457SAndroid Build Coastguard Worker                     << ", input channel: " << oc << " / " << input_channels()
331*4bdc9457SAndroid Build Coastguard Worker                     << ", input stride: " << input_channels_stride()
332*4bdc9457SAndroid Build Coastguard Worker                     << ", output stride: " << output_channels_stride();
333*4bdc9457SAndroid Build Coastguard Worker                 }
334*4bdc9457SAndroid Build Coastguard Worker               }
335*4bdc9457SAndroid Build Coastguard Worker             }
336*4bdc9457SAndroid Build Coastguard Worker           }
337*4bdc9457SAndroid Build Coastguard Worker         }
338*4bdc9457SAndroid Build Coastguard Worker       }
339*4bdc9457SAndroid Build Coastguard Worker     }
340*4bdc9457SAndroid Build Coastguard Worker   }
341*4bdc9457SAndroid Build Coastguard Worker 
342*4bdc9457SAndroid Build Coastguard Worker  private:
343*4bdc9457SAndroid Build Coastguard Worker   size_t input_height_{1};
344*4bdc9457SAndroid Build Coastguard Worker   size_t input_width_{1};
345*4bdc9457SAndroid Build Coastguard Worker   size_t input_channels_{1};
346*4bdc9457SAndroid Build Coastguard Worker   size_t block_size_{2};
347*4bdc9457SAndroid Build Coastguard Worker   size_t batch_size_{1};
348*4bdc9457SAndroid Build Coastguard Worker   size_t input_channels_stride_{0};
349*4bdc9457SAndroid Build Coastguard Worker   size_t output_channels_stride_{0};
350*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{1};
351*4bdc9457SAndroid Build Coastguard Worker };
352