1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // Copyright 2019 Google LLC 5 // 6 // This source code is licensed under the BSD-style license found in the 7 // LICENSE file in the root directory of this source tree. 8 9 #pragma once 10 11 #include <gtest/gtest.h> 12 13 #include <algorithm> 14 #include <cassert> 15 #include <cstddef> 16 #include <cstdlib> 17 #include <limits> 18 #include <random> 19 #include <vector> 20 21 #include <xnnpack.h> 22 23 24 class ChannelShuffleOperatorTester { 25 public: groups(size_t groups)26 inline ChannelShuffleOperatorTester& groups(size_t groups) { 27 assert(groups != 0); 28 this->groups_ = groups; 29 return *this; 30 } 31 groups()32 inline size_t groups() const { 33 return this->groups_; 34 } 35 group_channels(size_t group_channels)36 inline ChannelShuffleOperatorTester& group_channels(size_t group_channels) { 37 assert(group_channels != 0); 38 this->group_channels_ = group_channels; 39 return *this; 40 } 41 group_channels()42 inline size_t group_channels() const { 43 return this->group_channels_; 44 } 45 channels()46 inline size_t channels() const { 47 return groups() * group_channels(); 48 } 49 input_stride(size_t input_stride)50 inline ChannelShuffleOperatorTester& input_stride(size_t input_stride) { 51 assert(input_stride != 0); 52 this->input_stride_ = input_stride; 53 return *this; 54 } 55 input_stride()56 inline size_t input_stride() const { 57 if (this->input_stride_ == 0) { 58 return channels(); 59 } else { 60 assert(this->input_stride_ >= channels()); 61 return this->input_stride_; 62 } 63 } 64 output_stride(size_t output_stride)65 inline ChannelShuffleOperatorTester& output_stride(size_t output_stride) { 66 assert(output_stride != 0); 67 this->output_stride_ = output_stride; 68 return *this; 69 } 70 output_stride()71 inline size_t output_stride() const { 72 if (this->output_stride_ == 0) { 73 return channels(); 74 } else { 75 assert(this->output_stride_ >= channels()); 76 return this->output_stride_; 77 } 78 } 79 batch_size(size_t batch_size)80 inline ChannelShuffleOperatorTester& batch_size(size_t batch_size) { 81 assert(batch_size != 0); 82 this->batch_size_ = batch_size; 83 return *this; 84 } 85 batch_size()86 inline size_t batch_size() const { 87 return this->batch_size_; 88 } 89 iterations(size_t iterations)90 inline ChannelShuffleOperatorTester& iterations(size_t iterations) { 91 this->iterations_ = iterations; 92 return *this; 93 } 94 iterations()95 inline size_t iterations() const { 96 return this->iterations_; 97 } 98 TestX8()99 void TestX8() const { 100 std::random_device random_device; 101 auto rng = std::mt19937(random_device()); 102 std::uniform_int_distribution<int32_t> u8dist( 103 std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 104 105 std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + (batch_size() - 1) * input_stride() + channels()); 106 std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels()); 107 for (size_t iteration = 0; iteration < iterations(); iteration++) { 108 std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 109 std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 110 111 // Create, setup, run, and destroy Channel Shuffle operator. 112 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 113 xnn_operator_t channel_shuffle_op = nullptr; 114 115 ASSERT_EQ(xnn_status_success, 116 xnn_create_channel_shuffle_nc_x8( 117 groups(), group_channels(), 118 input_stride(), output_stride(), 119 0, &channel_shuffle_op)); 120 ASSERT_NE(nullptr, channel_shuffle_op); 121 122 // Smart pointer to automatically delete channel_shuffle_op. 123 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator); 124 125 ASSERT_EQ(xnn_status_success, 126 xnn_setup_channel_shuffle_nc_x8( 127 channel_shuffle_op, 128 batch_size(), 129 input.data(), output.data(), 130 nullptr /* thread pool */)); 131 132 ASSERT_EQ(xnn_status_success, 133 xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */)); 134 135 // Verify results. 136 for (size_t i = 0; i < batch_size(); i++) { 137 for (size_t g = 0; g < groups(); g++) { 138 for (size_t c = 0; c < group_channels(); c++) { 139 ASSERT_EQ(int32_t(input[i * input_stride() + g * group_channels() + c]), 140 int32_t(output[i * output_stride() + c * groups() + g])) 141 << "batch index " << i << ", group " << g << ", channel " << c; 142 } 143 } 144 } 145 } 146 } 147 TestX32()148 void TestX32() const { 149 std::random_device random_device; 150 auto rng = std::mt19937(random_device()); 151 std::uniform_int_distribution<uint32_t> u32dist; 152 153 std::vector<uint32_t> input(XNN_EXTRA_BYTES / sizeof(uint32_t) + (batch_size() - 1) * input_stride() + channels()); 154 std::vector<uint32_t> output((batch_size() - 1) * output_stride() + channels()); 155 for (size_t iteration = 0; iteration < iterations(); iteration++) { 156 std::generate(input.begin(), input.end(), [&]() { return u32dist(rng); }); 157 std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEAF)); 158 159 // Create, setup, run, and destroy Channel Shuffle operator. 160 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 161 xnn_operator_t channel_shuffle_op = nullptr; 162 163 ASSERT_EQ(xnn_status_success, 164 xnn_create_channel_shuffle_nc_x32( 165 groups(), group_channels(), 166 input_stride(), output_stride(), 167 0, &channel_shuffle_op)); 168 ASSERT_NE(nullptr, channel_shuffle_op); 169 170 // Smart pointer to automatically delete channel_shuffle_op. 171 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator); 172 173 ASSERT_EQ(xnn_status_success, 174 xnn_setup_channel_shuffle_nc_x32( 175 channel_shuffle_op, 176 batch_size(), 177 input.data(), output.data(), 178 nullptr /* thread pool */)); 179 180 ASSERT_EQ(xnn_status_success, 181 xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */)); 182 183 // Verify results. 184 for (size_t i = 0; i < batch_size(); i++) { 185 for (size_t g = 0; g < groups(); g++) { 186 for (size_t c = 0; c < group_channels(); c++) { 187 ASSERT_EQ(input[i * input_stride() + g * group_channels() + c], 188 output[i * output_stride() + c * groups() + g]) 189 << "batch index " << i << ", group " << g << ", channel " << c; 190 } 191 } 192 } 193 } 194 } 195 196 private: 197 size_t groups_{1}; 198 size_t group_channels_{1}; 199 size_t batch_size_{1}; 200 size_t input_stride_{0}; 201 size_t output_stride_{0}; 202 size_t iterations_{15}; 203 }; 204