xref: /aosp_15_r20/external/XNNPACK/test/channel-shuffle-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
22 
23 
24 class ChannelShuffleOperatorTester {
25  public:
groups(size_t groups)26   inline ChannelShuffleOperatorTester& groups(size_t groups) {
27     assert(groups != 0);
28     this->groups_ = groups;
29     return *this;
30   }
31 
groups()32   inline size_t groups() const {
33     return this->groups_;
34   }
35 
group_channels(size_t group_channels)36   inline ChannelShuffleOperatorTester& group_channels(size_t group_channels) {
37     assert(group_channels != 0);
38     this->group_channels_ = group_channels;
39     return *this;
40   }
41 
group_channels()42   inline size_t group_channels() const {
43     return this->group_channels_;
44   }
45 
channels()46   inline size_t channels() const {
47     return groups() * group_channels();
48   }
49 
input_stride(size_t input_stride)50   inline ChannelShuffleOperatorTester& input_stride(size_t input_stride) {
51     assert(input_stride != 0);
52     this->input_stride_ = input_stride;
53     return *this;
54   }
55 
input_stride()56   inline size_t input_stride() const {
57     if (this->input_stride_ == 0) {
58       return channels();
59     } else {
60       assert(this->input_stride_ >= channels());
61       return this->input_stride_;
62     }
63   }
64 
output_stride(size_t output_stride)65   inline ChannelShuffleOperatorTester& output_stride(size_t output_stride) {
66     assert(output_stride != 0);
67     this->output_stride_ = output_stride;
68     return *this;
69   }
70 
output_stride()71   inline size_t output_stride() const {
72     if (this->output_stride_ == 0) {
73       return channels();
74     } else {
75       assert(this->output_stride_ >= channels());
76       return this->output_stride_;
77     }
78   }
79 
batch_size(size_t batch_size)80   inline ChannelShuffleOperatorTester& batch_size(size_t batch_size) {
81     assert(batch_size != 0);
82     this->batch_size_ = batch_size;
83     return *this;
84   }
85 
batch_size()86   inline size_t batch_size() const {
87     return this->batch_size_;
88   }
89 
iterations(size_t iterations)90   inline ChannelShuffleOperatorTester& iterations(size_t iterations) {
91     this->iterations_ = iterations;
92     return *this;
93   }
94 
iterations()95   inline size_t iterations() const {
96     return this->iterations_;
97   }
98 
TestX8()99   void TestX8() const {
100     std::random_device random_device;
101     auto rng = std::mt19937(random_device());
102     std::uniform_int_distribution<int32_t> u8dist(
103       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
104 
105     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + (batch_size() - 1) * input_stride() + channels());
106     std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
107     for (size_t iteration = 0; iteration < iterations(); iteration++) {
108       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
109       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
110 
111       // Create, setup, run, and destroy Channel Shuffle operator.
112       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
113       xnn_operator_t channel_shuffle_op = nullptr;
114 
115       ASSERT_EQ(xnn_status_success,
116         xnn_create_channel_shuffle_nc_x8(
117           groups(), group_channels(),
118           input_stride(), output_stride(),
119           0, &channel_shuffle_op));
120       ASSERT_NE(nullptr, channel_shuffle_op);
121 
122       // Smart pointer to automatically delete channel_shuffle_op.
123       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator);
124 
125       ASSERT_EQ(xnn_status_success,
126         xnn_setup_channel_shuffle_nc_x8(
127           channel_shuffle_op,
128           batch_size(),
129           input.data(), output.data(),
130           nullptr /* thread pool */));
131 
132       ASSERT_EQ(xnn_status_success,
133         xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */));
134 
135       // Verify results.
136       for (size_t i = 0; i < batch_size(); i++) {
137         for (size_t g = 0; g < groups(); g++) {
138           for (size_t c = 0; c < group_channels(); c++) {
139             ASSERT_EQ(int32_t(input[i * input_stride() + g * group_channels() + c]),
140                 int32_t(output[i * output_stride() + c * groups() + g]))
141               << "batch index " << i << ", group " << g << ", channel " << c;
142           }
143         }
144       }
145     }
146   }
147 
TestX32()148   void TestX32() const {
149     std::random_device random_device;
150     auto rng = std::mt19937(random_device());
151     std::uniform_int_distribution<uint32_t> u32dist;
152 
153     std::vector<uint32_t> input(XNN_EXTRA_BYTES / sizeof(uint32_t) + (batch_size() - 1) * input_stride() + channels());
154     std::vector<uint32_t> output((batch_size() - 1) * output_stride() + channels());
155     for (size_t iteration = 0; iteration < iterations(); iteration++) {
156       std::generate(input.begin(), input.end(), [&]() { return u32dist(rng); });
157       std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEAF));
158 
159       // Create, setup, run, and destroy Channel Shuffle operator.
160       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
161       xnn_operator_t channel_shuffle_op = nullptr;
162 
163       ASSERT_EQ(xnn_status_success,
164         xnn_create_channel_shuffle_nc_x32(
165           groups(), group_channels(),
166           input_stride(), output_stride(),
167           0, &channel_shuffle_op));
168       ASSERT_NE(nullptr, channel_shuffle_op);
169 
170       // Smart pointer to automatically delete channel_shuffle_op.
171       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator);
172 
173       ASSERT_EQ(xnn_status_success,
174         xnn_setup_channel_shuffle_nc_x32(
175           channel_shuffle_op,
176           batch_size(),
177           input.data(), output.data(),
178           nullptr /* thread pool */));
179 
180       ASSERT_EQ(xnn_status_success,
181         xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */));
182 
183       // Verify results.
184       for (size_t i = 0; i < batch_size(); i++) {
185         for (size_t g = 0; g < groups(); g++) {
186           for (size_t c = 0; c < group_channels(); c++) {
187             ASSERT_EQ(input[i * input_stride() + g * group_channels() + c],
188                 output[i * output_stride() + c * groups() + g])
189               << "batch index " << i << ", group " << g << ", channel " << c;
190           }
191         }
192       }
193     }
194   }
195 
196  private:
197   size_t groups_{1};
198   size_t group_channels_{1};
199   size_t batch_size_{1};
200   size_t input_stride_{0};
201   size_t output_stride_{0};
202   size_t iterations_{15};
203 };
204