1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 12*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 14*4bdc9457SAndroid Build Coastguard Worker #include <random> 15*4bdc9457SAndroid Build Coastguard Worker #include <vector> 16*4bdc9457SAndroid Build Coastguard Worker 17*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 18*4bdc9457SAndroid Build Coastguard Worker 19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 20*4bdc9457SAndroid Build Coastguard Worker 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker class HardSwishOperatorTester { 23*4bdc9457SAndroid Build Coastguard Worker public: channels(size_t channels)24*4bdc9457SAndroid Build Coastguard Worker inline HardSwishOperatorTester& channels(size_t channels) { 25*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 26*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 27*4bdc9457SAndroid Build Coastguard Worker return *this; 28*4bdc9457SAndroid Build Coastguard Worker } 29*4bdc9457SAndroid Build Coastguard Worker channels()30*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 31*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 32*4bdc9457SAndroid Build Coastguard Worker } 33*4bdc9457SAndroid Build Coastguard Worker input_stride(size_t input_stride)34*4bdc9457SAndroid Build Coastguard Worker inline HardSwishOperatorTester& input_stride(size_t input_stride) { 35*4bdc9457SAndroid Build Coastguard Worker assert(input_stride != 0); 36*4bdc9457SAndroid Build Coastguard Worker this->input_stride_ = input_stride; 37*4bdc9457SAndroid Build Coastguard Worker return *this; 38*4bdc9457SAndroid Build Coastguard Worker } 39*4bdc9457SAndroid Build Coastguard Worker input_stride()40*4bdc9457SAndroid Build Coastguard Worker inline size_t input_stride() const { 41*4bdc9457SAndroid Build Coastguard Worker if (this->input_stride_ == 0) { 42*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 43*4bdc9457SAndroid Build Coastguard Worker } else { 44*4bdc9457SAndroid Build Coastguard Worker assert(this->input_stride_ >= this->channels_); 45*4bdc9457SAndroid Build Coastguard Worker return this->input_stride_; 46*4bdc9457SAndroid Build Coastguard Worker } 47*4bdc9457SAndroid Build Coastguard Worker } 48*4bdc9457SAndroid Build Coastguard Worker output_stride(size_t output_stride)49*4bdc9457SAndroid Build Coastguard Worker inline HardSwishOperatorTester& output_stride(size_t output_stride) { 50*4bdc9457SAndroid Build Coastguard Worker assert(output_stride != 0); 51*4bdc9457SAndroid Build Coastguard Worker this->output_stride_ = output_stride; 52*4bdc9457SAndroid Build Coastguard Worker return *this; 53*4bdc9457SAndroid Build Coastguard Worker } 54*4bdc9457SAndroid Build Coastguard Worker output_stride()55*4bdc9457SAndroid Build Coastguard Worker inline size_t output_stride() const { 56*4bdc9457SAndroid Build Coastguard Worker if (this->output_stride_ == 0) { 57*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 58*4bdc9457SAndroid Build Coastguard Worker } else { 59*4bdc9457SAndroid Build Coastguard Worker assert(this->output_stride_ >= this->channels_); 60*4bdc9457SAndroid Build Coastguard Worker return this->output_stride_; 61*4bdc9457SAndroid Build Coastguard Worker } 62*4bdc9457SAndroid Build Coastguard Worker } 63*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)64*4bdc9457SAndroid Build Coastguard Worker inline HardSwishOperatorTester& batch_size(size_t batch_size) { 65*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 66*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 67*4bdc9457SAndroid Build Coastguard Worker return *this; 68*4bdc9457SAndroid Build Coastguard Worker } 69*4bdc9457SAndroid Build Coastguard Worker batch_size()70*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 71*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 72*4bdc9457SAndroid Build Coastguard Worker } 73*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)74*4bdc9457SAndroid Build Coastguard Worker inline HardSwishOperatorTester& iterations(size_t iterations) { 75*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 76*4bdc9457SAndroid Build Coastguard Worker return *this; 77*4bdc9457SAndroid Build Coastguard Worker } 78*4bdc9457SAndroid Build Coastguard Worker iterations()79*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 80*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 81*4bdc9457SAndroid Build Coastguard Worker } 82*4bdc9457SAndroid Build Coastguard Worker TestF16()83*4bdc9457SAndroid Build Coastguard Worker void TestF16() const { 84*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 85*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 86*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-4.0f, 4.0f); 87*4bdc9457SAndroid Build Coastguard Worker 88*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + 89*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 90*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels()); 91*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 92*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 93*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 94*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 95*4bdc9457SAndroid Build Coastguard Worker 96*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 97*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 98*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 99*4bdc9457SAndroid Build Coastguard Worker const float x = fp16_ieee_to_fp32_value(input[i * input_stride() + c]); 100*4bdc9457SAndroid Build Coastguard Worker const float y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f; 101*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = y; 102*4bdc9457SAndroid Build Coastguard Worker } 103*4bdc9457SAndroid Build Coastguard Worker } 104*4bdc9457SAndroid Build Coastguard Worker 105*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy HardSwish operator. 106*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 107*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t hardswish_op = nullptr; 108*4bdc9457SAndroid Build Coastguard Worker 109*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_hardswish_nc_f16( 110*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 111*4bdc9457SAndroid Build Coastguard Worker 0, &hardswish_op); 112*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 113*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 114*4bdc9457SAndroid Build Coastguard Worker } 115*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 116*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, hardswish_op); 117*4bdc9457SAndroid Build Coastguard Worker 118*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete hardswish_op. 119*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_hardswish_op(hardswish_op, xnn_delete_operator); 120*4bdc9457SAndroid Build Coastguard Worker 121*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 122*4bdc9457SAndroid Build Coastguard Worker xnn_setup_hardswish_nc_f16( 123*4bdc9457SAndroid Build Coastguard Worker hardswish_op, 124*4bdc9457SAndroid Build Coastguard Worker batch_size(), 125*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 126*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 127*4bdc9457SAndroid Build Coastguard Worker 128*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 129*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(hardswish_op, nullptr /* thread pool */)); 130*4bdc9457SAndroid Build Coastguard Worker 131*4bdc9457SAndroid Build Coastguard Worker // Verify results. 132*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 133*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 134*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_ref[i * channels() + c], std::max(1.0e-3f, std::abs(output_ref[i * channels() + c]) * 1.0e-2f)) 135*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels(); 136*4bdc9457SAndroid Build Coastguard Worker } 137*4bdc9457SAndroid Build Coastguard Worker } 138*4bdc9457SAndroid Build Coastguard Worker } 139*4bdc9457SAndroid Build Coastguard Worker } 140*4bdc9457SAndroid Build Coastguard Worker TestF32()141*4bdc9457SAndroid Build Coastguard Worker void TestF32() const { 142*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 143*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 144*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-4.0f, 4.0f); 145*4bdc9457SAndroid Build Coastguard Worker 146*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + 147*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 148*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() - 1) * output_stride() + channels()); 149*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 150*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 151*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 152*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 153*4bdc9457SAndroid Build Coastguard Worker 154*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 155*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 156*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 157*4bdc9457SAndroid Build Coastguard Worker const float x = input[i * input_stride() + c]; 158*4bdc9457SAndroid Build Coastguard Worker const float y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f; 159*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = y; 160*4bdc9457SAndroid Build Coastguard Worker } 161*4bdc9457SAndroid Build Coastguard Worker } 162*4bdc9457SAndroid Build Coastguard Worker 163*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy HardSwish operator. 164*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 165*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t hardswish_op = nullptr; 166*4bdc9457SAndroid Build Coastguard Worker 167*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 168*4bdc9457SAndroid Build Coastguard Worker xnn_create_hardswish_nc_f32( 169*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 170*4bdc9457SAndroid Build Coastguard Worker 0, &hardswish_op)); 171*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, hardswish_op); 172*4bdc9457SAndroid Build Coastguard Worker 173*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete hardswish_op. 174*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_hardswish_op(hardswish_op, xnn_delete_operator); 175*4bdc9457SAndroid Build Coastguard Worker 176*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 177*4bdc9457SAndroid Build Coastguard Worker xnn_setup_hardswish_nc_f32( 178*4bdc9457SAndroid Build Coastguard Worker hardswish_op, 179*4bdc9457SAndroid Build Coastguard Worker batch_size(), 180*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 181*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 182*4bdc9457SAndroid Build Coastguard Worker 183*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 184*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(hardswish_op, nullptr /* thread pool */)); 185*4bdc9457SAndroid Build Coastguard Worker 186*4bdc9457SAndroid Build Coastguard Worker // Verify results. 187*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 188*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 189*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output_ref[i * channels() + c], output[i * output_stride() + c], std::max(1.0e-7f, std::abs(output[i * output_stride() + c]) * 1.0e-6f)) 190*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels(); 191*4bdc9457SAndroid Build Coastguard Worker } 192*4bdc9457SAndroid Build Coastguard Worker } 193*4bdc9457SAndroid Build Coastguard Worker } 194*4bdc9457SAndroid Build Coastguard Worker } 195*4bdc9457SAndroid Build Coastguard Worker 196*4bdc9457SAndroid Build Coastguard Worker private: 197*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 198*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 199*4bdc9457SAndroid Build Coastguard Worker size_t input_stride_{0}; 200*4bdc9457SAndroid Build Coastguard Worker size_t output_stride_{0}; 201*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{15}; 202*4bdc9457SAndroid Build Coastguard Worker }; 203