xref: /aosp_15_r20/external/XNNPACK/test/channel-shuffle-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker 
#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
22*4bdc9457SAndroid Build Coastguard Worker 
23*4bdc9457SAndroid Build Coastguard Worker 
24*4bdc9457SAndroid Build Coastguard Worker class ChannelShuffleOperatorTester {
25*4bdc9457SAndroid Build Coastguard Worker  public:
groups(size_t groups)26*4bdc9457SAndroid Build Coastguard Worker   inline ChannelShuffleOperatorTester& groups(size_t groups) {
27*4bdc9457SAndroid Build Coastguard Worker     assert(groups != 0);
28*4bdc9457SAndroid Build Coastguard Worker     this->groups_ = groups;
29*4bdc9457SAndroid Build Coastguard Worker     return *this;
30*4bdc9457SAndroid Build Coastguard Worker   }
31*4bdc9457SAndroid Build Coastguard Worker 
groups()32*4bdc9457SAndroid Build Coastguard Worker   inline size_t groups() const {
33*4bdc9457SAndroid Build Coastguard Worker     return this->groups_;
34*4bdc9457SAndroid Build Coastguard Worker   }
35*4bdc9457SAndroid Build Coastguard Worker 
group_channels(size_t group_channels)36*4bdc9457SAndroid Build Coastguard Worker   inline ChannelShuffleOperatorTester& group_channels(size_t group_channels) {
37*4bdc9457SAndroid Build Coastguard Worker     assert(group_channels != 0);
38*4bdc9457SAndroid Build Coastguard Worker     this->group_channels_ = group_channels;
39*4bdc9457SAndroid Build Coastguard Worker     return *this;
40*4bdc9457SAndroid Build Coastguard Worker   }
41*4bdc9457SAndroid Build Coastguard Worker 
group_channels()42*4bdc9457SAndroid Build Coastguard Worker   inline size_t group_channels() const {
43*4bdc9457SAndroid Build Coastguard Worker     return this->group_channels_;
44*4bdc9457SAndroid Build Coastguard Worker   }
45*4bdc9457SAndroid Build Coastguard Worker 
channels()46*4bdc9457SAndroid Build Coastguard Worker   inline size_t channels() const {
47*4bdc9457SAndroid Build Coastguard Worker     return groups() * group_channels();
48*4bdc9457SAndroid Build Coastguard Worker   }
49*4bdc9457SAndroid Build Coastguard Worker 
input_stride(size_t input_stride)50*4bdc9457SAndroid Build Coastguard Worker   inline ChannelShuffleOperatorTester& input_stride(size_t input_stride) {
51*4bdc9457SAndroid Build Coastguard Worker     assert(input_stride != 0);
52*4bdc9457SAndroid Build Coastguard Worker     this->input_stride_ = input_stride;
53*4bdc9457SAndroid Build Coastguard Worker     return *this;
54*4bdc9457SAndroid Build Coastguard Worker   }
55*4bdc9457SAndroid Build Coastguard Worker 
input_stride()56*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_stride() const {
57*4bdc9457SAndroid Build Coastguard Worker     if (this->input_stride_ == 0) {
58*4bdc9457SAndroid Build Coastguard Worker       return channels();
59*4bdc9457SAndroid Build Coastguard Worker     } else {
60*4bdc9457SAndroid Build Coastguard Worker       assert(this->input_stride_ >= channels());
61*4bdc9457SAndroid Build Coastguard Worker       return this->input_stride_;
62*4bdc9457SAndroid Build Coastguard Worker     }
63*4bdc9457SAndroid Build Coastguard Worker   }
64*4bdc9457SAndroid Build Coastguard Worker 
output_stride(size_t output_stride)65*4bdc9457SAndroid Build Coastguard Worker   inline ChannelShuffleOperatorTester& output_stride(size_t output_stride) {
66*4bdc9457SAndroid Build Coastguard Worker     assert(output_stride != 0);
67*4bdc9457SAndroid Build Coastguard Worker     this->output_stride_ = output_stride;
68*4bdc9457SAndroid Build Coastguard Worker     return *this;
69*4bdc9457SAndroid Build Coastguard Worker   }
70*4bdc9457SAndroid Build Coastguard Worker 
output_stride()71*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_stride() const {
72*4bdc9457SAndroid Build Coastguard Worker     if (this->output_stride_ == 0) {
73*4bdc9457SAndroid Build Coastguard Worker       return channels();
74*4bdc9457SAndroid Build Coastguard Worker     } else {
75*4bdc9457SAndroid Build Coastguard Worker       assert(this->output_stride_ >= channels());
76*4bdc9457SAndroid Build Coastguard Worker       return this->output_stride_;
77*4bdc9457SAndroid Build Coastguard Worker     }
78*4bdc9457SAndroid Build Coastguard Worker   }
79*4bdc9457SAndroid Build Coastguard Worker 
batch_size(size_t batch_size)80*4bdc9457SAndroid Build Coastguard Worker   inline ChannelShuffleOperatorTester& batch_size(size_t batch_size) {
81*4bdc9457SAndroid Build Coastguard Worker     assert(batch_size != 0);
82*4bdc9457SAndroid Build Coastguard Worker     this->batch_size_ = batch_size;
83*4bdc9457SAndroid Build Coastguard Worker     return *this;
84*4bdc9457SAndroid Build Coastguard Worker   }
85*4bdc9457SAndroid Build Coastguard Worker 
batch_size()86*4bdc9457SAndroid Build Coastguard Worker   inline size_t batch_size() const {
87*4bdc9457SAndroid Build Coastguard Worker     return this->batch_size_;
88*4bdc9457SAndroid Build Coastguard Worker   }
89*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)90*4bdc9457SAndroid Build Coastguard Worker   inline ChannelShuffleOperatorTester& iterations(size_t iterations) {
91*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
92*4bdc9457SAndroid Build Coastguard Worker     return *this;
93*4bdc9457SAndroid Build Coastguard Worker   }
94*4bdc9457SAndroid Build Coastguard Worker 
iterations()95*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
96*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
97*4bdc9457SAndroid Build Coastguard Worker   }
98*4bdc9457SAndroid Build Coastguard Worker 
TestX8()99*4bdc9457SAndroid Build Coastguard Worker   void TestX8() const {
100*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
101*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
102*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
103*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
104*4bdc9457SAndroid Build Coastguard Worker 
105*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + (batch_size() - 1) * input_stride() + channels());
106*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
107*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
108*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
109*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
110*4bdc9457SAndroid Build Coastguard Worker 
111*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Channel Shuffle operator.
112*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
113*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t channel_shuffle_op = nullptr;
114*4bdc9457SAndroid Build Coastguard Worker 
115*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
116*4bdc9457SAndroid Build Coastguard Worker         xnn_create_channel_shuffle_nc_x8(
117*4bdc9457SAndroid Build Coastguard Worker           groups(), group_channels(),
118*4bdc9457SAndroid Build Coastguard Worker           input_stride(), output_stride(),
119*4bdc9457SAndroid Build Coastguard Worker           0, &channel_shuffle_op));
120*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, channel_shuffle_op);
121*4bdc9457SAndroid Build Coastguard Worker 
122*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete channel_shuffle_op.
123*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator);
124*4bdc9457SAndroid Build Coastguard Worker 
125*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
126*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_channel_shuffle_nc_x8(
127*4bdc9457SAndroid Build Coastguard Worker           channel_shuffle_op,
128*4bdc9457SAndroid Build Coastguard Worker           batch_size(),
129*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
130*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
131*4bdc9457SAndroid Build Coastguard Worker 
132*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
133*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */));
134*4bdc9457SAndroid Build Coastguard Worker 
135*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
136*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
137*4bdc9457SAndroid Build Coastguard Worker         for (size_t g = 0; g < groups(); g++) {
138*4bdc9457SAndroid Build Coastguard Worker           for (size_t c = 0; c < group_channels(); c++) {
139*4bdc9457SAndroid Build Coastguard Worker             ASSERT_EQ(int32_t(input[i * input_stride() + g * group_channels() + c]),
140*4bdc9457SAndroid Build Coastguard Worker                 int32_t(output[i * output_stride() + c * groups() + g]))
141*4bdc9457SAndroid Build Coastguard Worker               << "batch index " << i << ", group " << g << ", channel " << c;
142*4bdc9457SAndroid Build Coastguard Worker           }
143*4bdc9457SAndroid Build Coastguard Worker         }
144*4bdc9457SAndroid Build Coastguard Worker       }
145*4bdc9457SAndroid Build Coastguard Worker     }
146*4bdc9457SAndroid Build Coastguard Worker   }
147*4bdc9457SAndroid Build Coastguard Worker 
TestX32()148*4bdc9457SAndroid Build Coastguard Worker   void TestX32() const {
149*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
150*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
151*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<uint32_t> u32dist;
152*4bdc9457SAndroid Build Coastguard Worker 
153*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> input(XNN_EXTRA_BYTES / sizeof(uint32_t) + (batch_size() - 1) * input_stride() + channels());
154*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> output((batch_size() - 1) * output_stride() + channels());
155*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
156*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u32dist(rng); });
157*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEAF));
158*4bdc9457SAndroid Build Coastguard Worker 
159*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Channel Shuffle operator.
160*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
161*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t channel_shuffle_op = nullptr;
162*4bdc9457SAndroid Build Coastguard Worker 
163*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
164*4bdc9457SAndroid Build Coastguard Worker         xnn_create_channel_shuffle_nc_x32(
165*4bdc9457SAndroid Build Coastguard Worker           groups(), group_channels(),
166*4bdc9457SAndroid Build Coastguard Worker           input_stride(), output_stride(),
167*4bdc9457SAndroid Build Coastguard Worker           0, &channel_shuffle_op));
168*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, channel_shuffle_op);
169*4bdc9457SAndroid Build Coastguard Worker 
170*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete channel_shuffle_op.
171*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator);
172*4bdc9457SAndroid Build Coastguard Worker 
173*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
174*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_channel_shuffle_nc_x32(
175*4bdc9457SAndroid Build Coastguard Worker           channel_shuffle_op,
176*4bdc9457SAndroid Build Coastguard Worker           batch_size(),
177*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
178*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
179*4bdc9457SAndroid Build Coastguard Worker 
180*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
181*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */));
182*4bdc9457SAndroid Build Coastguard Worker 
183*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
184*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
185*4bdc9457SAndroid Build Coastguard Worker         for (size_t g = 0; g < groups(); g++) {
186*4bdc9457SAndroid Build Coastguard Worker           for (size_t c = 0; c < group_channels(); c++) {
187*4bdc9457SAndroid Build Coastguard Worker             ASSERT_EQ(input[i * input_stride() + g * group_channels() + c],
188*4bdc9457SAndroid Build Coastguard Worker                 output[i * output_stride() + c * groups() + g])
189*4bdc9457SAndroid Build Coastguard Worker               << "batch index " << i << ", group " << g << ", channel " << c;
190*4bdc9457SAndroid Build Coastguard Worker           }
191*4bdc9457SAndroid Build Coastguard Worker         }
192*4bdc9457SAndroid Build Coastguard Worker       }
193*4bdc9457SAndroid Build Coastguard Worker     }
194*4bdc9457SAndroid Build Coastguard Worker   }
195*4bdc9457SAndroid Build Coastguard Worker 
196*4bdc9457SAndroid Build Coastguard Worker  private:
197*4bdc9457SAndroid Build Coastguard Worker   size_t groups_{1};
198*4bdc9457SAndroid Build Coastguard Worker   size_t group_channels_{1};
199*4bdc9457SAndroid Build Coastguard Worker   size_t batch_size_{1};
200*4bdc9457SAndroid Build Coastguard Worker   size_t input_stride_{0};
201*4bdc9457SAndroid Build Coastguard Worker   size_t output_stride_{0};
202*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{15};
203*4bdc9457SAndroid Build Coastguard Worker };
204