1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2020 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 12*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 14*4bdc9457SAndroid Build Coastguard Worker #include <limits> 15*4bdc9457SAndroid Build Coastguard Worker #include <random> 16*4bdc9457SAndroid Build Coastguard Worker #include <vector> 17*4bdc9457SAndroid Build Coastguard Worker 18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Worker 21*4bdc9457SAndroid Build Coastguard Worker class CopyOperatorTester { 22*4bdc9457SAndroid Build Coastguard Worker public: channels(size_t channels)23*4bdc9457SAndroid Build Coastguard Worker inline CopyOperatorTester& channels(size_t channels) { 24*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 25*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 26*4bdc9457SAndroid Build Coastguard Worker return *this; 27*4bdc9457SAndroid Build Coastguard Worker } 28*4bdc9457SAndroid Build Coastguard Worker channels()29*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 30*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 31*4bdc9457SAndroid Build Coastguard Worker } 32*4bdc9457SAndroid Build Coastguard Worker input_stride(size_t input_stride)33*4bdc9457SAndroid Build Coastguard Worker inline CopyOperatorTester& input_stride(size_t input_stride) { 34*4bdc9457SAndroid Build Coastguard Worker assert(input_stride != 0); 35*4bdc9457SAndroid Build Coastguard Worker this->input_stride_ = input_stride; 36*4bdc9457SAndroid Build Coastguard Worker return *this; 37*4bdc9457SAndroid Build Coastguard Worker } 38*4bdc9457SAndroid Build Coastguard Worker input_stride()39*4bdc9457SAndroid Build Coastguard Worker inline size_t input_stride() const { 40*4bdc9457SAndroid Build Coastguard Worker if (this->input_stride_ == 0) { 41*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 42*4bdc9457SAndroid Build Coastguard Worker } else { 43*4bdc9457SAndroid Build Coastguard Worker assert(this->input_stride_ >= this->channels_); 44*4bdc9457SAndroid Build Coastguard Worker return this->input_stride_; 45*4bdc9457SAndroid Build Coastguard Worker } 46*4bdc9457SAndroid Build Coastguard Worker } 47*4bdc9457SAndroid Build Coastguard Worker output_stride(size_t output_stride)48*4bdc9457SAndroid Build Coastguard Worker inline CopyOperatorTester& output_stride(size_t output_stride) { 49*4bdc9457SAndroid Build Coastguard Worker assert(output_stride != 0); 50*4bdc9457SAndroid Build Coastguard Worker this->output_stride_ = output_stride; 51*4bdc9457SAndroid Build Coastguard Worker return *this; 52*4bdc9457SAndroid Build Coastguard Worker } 53*4bdc9457SAndroid Build Coastguard Worker output_stride()54*4bdc9457SAndroid Build Coastguard Worker inline size_t output_stride() const { 55*4bdc9457SAndroid Build Coastguard Worker if (this->output_stride_ == 0) { 56*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 57*4bdc9457SAndroid Build Coastguard Worker } else { 58*4bdc9457SAndroid Build Coastguard Worker assert(this->output_stride_ >= this->channels_); 59*4bdc9457SAndroid Build Coastguard Worker return this->output_stride_; 60*4bdc9457SAndroid Build Coastguard Worker } 61*4bdc9457SAndroid Build Coastguard Worker } 62*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)63*4bdc9457SAndroid Build Coastguard Worker inline CopyOperatorTester& batch_size(size_t batch_size) { 64*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 65*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 66*4bdc9457SAndroid Build Coastguard Worker return *this; 67*4bdc9457SAndroid Build Coastguard Worker } 68*4bdc9457SAndroid Build Coastguard Worker batch_size()69*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 70*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 71*4bdc9457SAndroid Build Coastguard Worker } 72*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)73*4bdc9457SAndroid Build Coastguard Worker inline CopyOperatorTester& iterations(size_t iterations) { 74*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 75*4bdc9457SAndroid Build Coastguard Worker return *this; 76*4bdc9457SAndroid Build Coastguard Worker } 77*4bdc9457SAndroid Build Coastguard Worker iterations()78*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 79*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 80*4bdc9457SAndroid Build Coastguard Worker } 81*4bdc9457SAndroid Build Coastguard Worker TestX8()82*4bdc9457SAndroid Build Coastguard Worker void TestX8() const { 83*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 84*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 85*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<uint32_t> u8dist( 86*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 87*4bdc9457SAndroid Build Coastguard Worker 88*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + 89*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 90*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels()); 91*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output_ref(batch_size() * channels()); 92*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 93*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 94*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xFA)); 95*4bdc9457SAndroid Build Coastguard Worker 96*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 97*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 98*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 99*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = input[i * input_stride() + c]; 100*4bdc9457SAndroid Build Coastguard Worker } 101*4bdc9457SAndroid Build Coastguard Worker } 102*4bdc9457SAndroid Build Coastguard Worker 103*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Copy operator. 104*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 105*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t copy_op = nullptr; 106*4bdc9457SAndroid Build Coastguard Worker 107*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 108*4bdc9457SAndroid Build Coastguard Worker xnn_create_copy_nc_x8( 109*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 110*4bdc9457SAndroid Build Coastguard Worker 0, ©_op)); 111*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, copy_op); 112*4bdc9457SAndroid Build Coastguard Worker 113*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete copy_op. 114*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_copy_op(copy_op, xnn_delete_operator); 115*4bdc9457SAndroid Build Coastguard Worker 116*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 117*4bdc9457SAndroid Build Coastguard Worker xnn_setup_copy_nc_x8( 118*4bdc9457SAndroid Build Coastguard Worker copy_op, 119*4bdc9457SAndroid Build Coastguard Worker batch_size(), 120*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 121*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 122*4bdc9457SAndroid Build Coastguard Worker 123*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 124*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(copy_op, nullptr /* thread pool */)); 125*4bdc9457SAndroid Build Coastguard Worker 126*4bdc9457SAndroid Build Coastguard Worker // Verify results. 127*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 128*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 129*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 130*4bdc9457SAndroid Build Coastguard Worker << "at batch " << i << " / " << batch_size() << ", channel = " << c << " / " << channels(); 131*4bdc9457SAndroid Build Coastguard Worker } 132*4bdc9457SAndroid Build Coastguard Worker } 133*4bdc9457SAndroid Build Coastguard Worker } 134*4bdc9457SAndroid Build Coastguard Worker } 135*4bdc9457SAndroid Build Coastguard Worker TestX16()136*4bdc9457SAndroid Build Coastguard Worker void TestX16() const { 137*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 138*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 139*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<uint16_t> u16dist; 140*4bdc9457SAndroid Build Coastguard Worker 141*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + 142*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 143*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels()); 144*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output_ref(batch_size() * channels()); 145*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 146*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u16dist(rng); }); 147*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0xDEAD)); 148*4bdc9457SAndroid Build Coastguard Worker 149*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 150*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 151*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 152*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = input[i * input_stride() + c]; 153*4bdc9457SAndroid Build Coastguard Worker } 154*4bdc9457SAndroid Build Coastguard Worker } 155*4bdc9457SAndroid Build Coastguard Worker 156*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Copy operator. 157*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 158*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t copy_op = nullptr; 159*4bdc9457SAndroid Build Coastguard Worker 160*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 161*4bdc9457SAndroid Build Coastguard Worker xnn_create_copy_nc_x16( 162*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 163*4bdc9457SAndroid Build Coastguard Worker 0, ©_op)); 164*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, copy_op); 165*4bdc9457SAndroid Build Coastguard Worker 166*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete copy_op. 167*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_copy_op(copy_op, xnn_delete_operator); 168*4bdc9457SAndroid Build Coastguard Worker 169*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 170*4bdc9457SAndroid Build Coastguard Worker xnn_setup_copy_nc_x16( 171*4bdc9457SAndroid Build Coastguard Worker copy_op, 172*4bdc9457SAndroid Build Coastguard Worker batch_size(), 173*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 174*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 175*4bdc9457SAndroid Build Coastguard Worker 176*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 177*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(copy_op, nullptr /* thread pool */)); 178*4bdc9457SAndroid Build Coastguard Worker 179*4bdc9457SAndroid Build Coastguard Worker // Verify results. 180*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 181*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 182*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 183*4bdc9457SAndroid Build Coastguard Worker << "at batch " << i << " / " << batch_size() << ", channel = " << c << " / " << channels(); 184*4bdc9457SAndroid Build Coastguard Worker } 185*4bdc9457SAndroid Build Coastguard Worker } 186*4bdc9457SAndroid Build Coastguard Worker } 187*4bdc9457SAndroid Build Coastguard Worker } 188*4bdc9457SAndroid Build Coastguard Worker TestX32()189*4bdc9457SAndroid Build Coastguard Worker void TestX32() const { 190*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 191*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 192*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<uint32_t> u32dist; 193*4bdc9457SAndroid Build Coastguard Worker 194*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> input(XNN_EXTRA_BYTES / sizeof(uint32_t) + 195*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 196*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> output((batch_size() - 1) * output_stride() + channels()); 197*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> output_ref(batch_size() * channels()); 198*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 199*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u32dist(rng); }); 200*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEEF)); 201*4bdc9457SAndroid Build Coastguard Worker 202*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 203*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 204*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 205*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = input[i * input_stride() + c]; 206*4bdc9457SAndroid Build Coastguard Worker } 207*4bdc9457SAndroid Build Coastguard Worker } 208*4bdc9457SAndroid Build Coastguard Worker 209*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Copy operator. 210*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 211*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t copy_op = nullptr; 212*4bdc9457SAndroid Build Coastguard Worker 213*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 214*4bdc9457SAndroid Build Coastguard Worker xnn_create_copy_nc_x32( 215*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 216*4bdc9457SAndroid Build Coastguard Worker 0, ©_op)); 217*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, copy_op); 218*4bdc9457SAndroid Build Coastguard Worker 219*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete copy_op. 220*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_copy_op(copy_op, xnn_delete_operator); 221*4bdc9457SAndroid Build Coastguard Worker 222*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 223*4bdc9457SAndroid Build Coastguard Worker xnn_setup_copy_nc_x32( 224*4bdc9457SAndroid Build Coastguard Worker copy_op, 225*4bdc9457SAndroid Build Coastguard Worker batch_size(), 226*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 227*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 228*4bdc9457SAndroid Build Coastguard Worker 229*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 230*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(copy_op, nullptr /* thread pool */)); 231*4bdc9457SAndroid Build Coastguard Worker 232*4bdc9457SAndroid Build Coastguard Worker // Verify results. 233*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 234*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 235*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 236*4bdc9457SAndroid Build Coastguard Worker << "at batch " << i << " / " << batch_size() << ", channel = " << c << " / " << channels(); 237*4bdc9457SAndroid Build Coastguard Worker } 238*4bdc9457SAndroid Build Coastguard Worker } 239*4bdc9457SAndroid Build Coastguard Worker } 240*4bdc9457SAndroid Build Coastguard Worker } 241*4bdc9457SAndroid Build Coastguard Worker 242*4bdc9457SAndroid Build Coastguard Worker private: 243*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 244*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 245*4bdc9457SAndroid Build Coastguard Worker size_t input_stride_{0}; 246*4bdc9457SAndroid Build Coastguard Worker size_t output_stride_{0}; 247*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{15}; 248*4bdc9457SAndroid Build Coastguard Worker }; 249