1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2021 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 12*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 14*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 15*4bdc9457SAndroid Build Coastguard Worker #include <functional> 16*4bdc9457SAndroid Build Coastguard Worker #include <limits> 17*4bdc9457SAndroid Build Coastguard Worker #include <random> 18*4bdc9457SAndroid Build Coastguard Worker #include <vector> 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker 23*4bdc9457SAndroid Build Coastguard Worker class TanhOperatorTester { 24*4bdc9457SAndroid Build Coastguard Worker public: channels(size_t channels)25*4bdc9457SAndroid Build Coastguard Worker inline TanhOperatorTester& channels(size_t channels) { 26*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 27*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 28*4bdc9457SAndroid Build Coastguard Worker return *this; 29*4bdc9457SAndroid Build Coastguard Worker } 30*4bdc9457SAndroid Build Coastguard Worker channels()31*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 32*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 33*4bdc9457SAndroid Build Coastguard Worker } 34*4bdc9457SAndroid Build Coastguard Worker input_stride(size_t input_stride)35*4bdc9457SAndroid Build Coastguard Worker inline TanhOperatorTester& input_stride(size_t input_stride) { 36*4bdc9457SAndroid Build Coastguard Worker assert(input_stride != 0); 37*4bdc9457SAndroid Build Coastguard Worker this->input_stride_ = input_stride; 38*4bdc9457SAndroid Build Coastguard Worker return *this; 39*4bdc9457SAndroid Build Coastguard Worker } 40*4bdc9457SAndroid Build Coastguard Worker input_stride()41*4bdc9457SAndroid Build Coastguard Worker inline size_t input_stride() const { 42*4bdc9457SAndroid Build Coastguard Worker if (this->input_stride_ == 0) { 43*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 44*4bdc9457SAndroid Build Coastguard Worker } else { 45*4bdc9457SAndroid Build Coastguard Worker assert(this->input_stride_ >= this->channels_); 46*4bdc9457SAndroid Build Coastguard Worker return this->input_stride_; 47*4bdc9457SAndroid Build Coastguard Worker } 48*4bdc9457SAndroid Build Coastguard Worker } 49*4bdc9457SAndroid Build Coastguard Worker output_stride(size_t output_stride)50*4bdc9457SAndroid Build Coastguard Worker inline TanhOperatorTester& output_stride(size_t output_stride) { 51*4bdc9457SAndroid Build Coastguard Worker assert(output_stride != 0); 52*4bdc9457SAndroid Build Coastguard Worker this->output_stride_ = output_stride; 53*4bdc9457SAndroid Build Coastguard Worker return *this; 54*4bdc9457SAndroid Build Coastguard Worker } 55*4bdc9457SAndroid Build Coastguard Worker output_stride()56*4bdc9457SAndroid Build Coastguard Worker inline size_t output_stride() const { 57*4bdc9457SAndroid Build Coastguard Worker if (this->output_stride_ == 0) { 58*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 59*4bdc9457SAndroid Build Coastguard Worker } else { 60*4bdc9457SAndroid Build Coastguard Worker assert(this->output_stride_ >= this->channels_); 61*4bdc9457SAndroid Build Coastguard Worker return this->output_stride_; 62*4bdc9457SAndroid Build Coastguard Worker } 63*4bdc9457SAndroid Build Coastguard Worker } 64*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)65*4bdc9457SAndroid Build Coastguard Worker inline TanhOperatorTester& batch_size(size_t batch_size) { 66*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 67*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 68*4bdc9457SAndroid Build Coastguard Worker return *this; 69*4bdc9457SAndroid Build Coastguard Worker } 70*4bdc9457SAndroid Build Coastguard Worker batch_size()71*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 72*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 73*4bdc9457SAndroid Build Coastguard Worker } 74*4bdc9457SAndroid Build Coastguard Worker input_scale(float input_scale)75*4bdc9457SAndroid Build Coastguard Worker inline TanhOperatorTester& input_scale(float input_scale) { 76*4bdc9457SAndroid Build Coastguard Worker assert(input_scale > 0.0f); 77*4bdc9457SAndroid Build Coastguard Worker assert(std::isnormal(input_scale)); 78*4bdc9457SAndroid Build Coastguard Worker this->input_scale_ = input_scale; 79*4bdc9457SAndroid Build Coastguard Worker return *this; 80*4bdc9457SAndroid Build Coastguard Worker } 81*4bdc9457SAndroid Build Coastguard Worker input_scale()82*4bdc9457SAndroid Build Coastguard Worker inline float input_scale() const { 83*4bdc9457SAndroid Build Coastguard Worker return this->input_scale_; 84*4bdc9457SAndroid Build Coastguard Worker } 85*4bdc9457SAndroid Build Coastguard Worker input_zero_point(uint8_t input_zero_point)86*4bdc9457SAndroid Build Coastguard Worker inline TanhOperatorTester& input_zero_point(uint8_t input_zero_point) { 87*4bdc9457SAndroid Build Coastguard Worker this->input_zero_point_ = input_zero_point; 88*4bdc9457SAndroid Build Coastguard Worker return *this; 89*4bdc9457SAndroid Build Coastguard Worker } 90*4bdc9457SAndroid Build Coastguard Worker input_zero_point()91*4bdc9457SAndroid Build Coastguard Worker inline uint8_t input_zero_point() const { 92*4bdc9457SAndroid Build Coastguard Worker return this->input_zero_point_; 93*4bdc9457SAndroid Build Coastguard Worker } 94*4bdc9457SAndroid Build Coastguard Worker output_scale()95*4bdc9457SAndroid Build Coastguard Worker inline float output_scale() const { 96*4bdc9457SAndroid Build Coastguard Worker return 1.0f / 128.0f; 97*4bdc9457SAndroid Build Coastguard Worker } 98*4bdc9457SAndroid Build Coastguard Worker output_zero_point()99*4bdc9457SAndroid Build Coastguard Worker inline uint8_t output_zero_point() const { 100*4bdc9457SAndroid Build Coastguard Worker return 128; 101*4bdc9457SAndroid Build Coastguard Worker } 102*4bdc9457SAndroid Build Coastguard Worker qmin(uint8_t qmin)103*4bdc9457SAndroid Build Coastguard Worker inline TanhOperatorTester& qmin(uint8_t qmin) { 104*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 105*4bdc9457SAndroid Build Coastguard Worker return *this; 106*4bdc9457SAndroid Build Coastguard Worker } 107*4bdc9457SAndroid Build Coastguard Worker qmin()108*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const { 109*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 110*4bdc9457SAndroid Build Coastguard Worker } 111*4bdc9457SAndroid Build Coastguard Worker qmax(uint8_t qmax)112*4bdc9457SAndroid Build Coastguard Worker inline TanhOperatorTester& qmax(uint8_t qmax) { 113*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 114*4bdc9457SAndroid Build Coastguard Worker return *this; 115*4bdc9457SAndroid Build Coastguard Worker } 116*4bdc9457SAndroid Build Coastguard Worker qmax()117*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const { 118*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 119*4bdc9457SAndroid Build Coastguard Worker } 120*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)121*4bdc9457SAndroid Build Coastguard Worker inline TanhOperatorTester& iterations(size_t iterations) { 122*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 123*4bdc9457SAndroid Build Coastguard Worker return *this; 124*4bdc9457SAndroid Build Coastguard Worker } 125*4bdc9457SAndroid Build Coastguard Worker iterations()126*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 127*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 128*4bdc9457SAndroid Build Coastguard Worker } 129*4bdc9457SAndroid Build Coastguard Worker TestQS8()130*4bdc9457SAndroid Build Coastguard Worker void TestQS8() const { 131*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 132*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 133*4bdc9457SAndroid Build Coastguard Worker auto i8rng = std::bind( 134*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()), 135*4bdc9457SAndroid Build Coastguard Worker std::ref(rng)); 136*4bdc9457SAndroid Build Coastguard Worker 137*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(int8_t)); 138*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output((batch_size() - 1) * output_stride() + channels()); 139*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 140*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 141*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), std::ref(i8rng)); 142*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), 0xA5); 143*4bdc9457SAndroid Build Coastguard Worker 144*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 145*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 146*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 147*4bdc9457SAndroid Build Coastguard Worker const float x = input_scale() * 148*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point() - 0x80)); 149*4bdc9457SAndroid Build Coastguard Worker const float tanh_x = std::tanh(x); 150*4bdc9457SAndroid Build Coastguard Worker const float scaled_tanh_x = tanh_x / output_scale(); 151*4bdc9457SAndroid Build Coastguard Worker float y = scaled_tanh_x; 152*4bdc9457SAndroid Build Coastguard Worker y = std::min<float>(y, int32_t(qmax() - 0x80) - int32_t(output_zero_point() - 0x80)); 153*4bdc9457SAndroid Build Coastguard Worker y = std::max<float>(y, int32_t(qmin() - 0x80) - int32_t(output_zero_point() - 0x80)); 154*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = y + int32_t(output_zero_point() - 0x80); 155*4bdc9457SAndroid Build Coastguard Worker } 156*4bdc9457SAndroid Build Coastguard Worker } 157*4bdc9457SAndroid Build Coastguard Worker 158*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Sigmoid operator. 159*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 160*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t tanh_op = nullptr; 161*4bdc9457SAndroid Build Coastguard Worker 162*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 163*4bdc9457SAndroid Build Coastguard Worker xnn_create_tanh_nc_qs8( 164*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 165*4bdc9457SAndroid Build Coastguard Worker int8_t(input_zero_point() - 0x80), input_scale(), 166*4bdc9457SAndroid Build Coastguard Worker int8_t(output_zero_point() - 0x80), output_scale(), 167*4bdc9457SAndroid Build Coastguard Worker int8_t(qmin() - 0x80), int8_t(qmax() - 0x80), 168*4bdc9457SAndroid Build Coastguard Worker 0, &tanh_op)); 169*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, tanh_op); 170*4bdc9457SAndroid Build Coastguard Worker 171*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete tanh_op. 172*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_tanh_op(tanh_op, xnn_delete_operator); 173*4bdc9457SAndroid Build Coastguard Worker 174*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 175*4bdc9457SAndroid Build Coastguard Worker xnn_setup_tanh_nc_qs8( 176*4bdc9457SAndroid Build Coastguard Worker tanh_op, 177*4bdc9457SAndroid Build Coastguard Worker batch_size(), 178*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 179*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 180*4bdc9457SAndroid Build Coastguard Worker 181*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 182*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(tanh_op, nullptr /* thread pool */)); 183*4bdc9457SAndroid Build Coastguard Worker 184*4bdc9457SAndroid Build Coastguard Worker // Verify results. 185*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 186*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 187*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f); 188*4bdc9457SAndroid Build Coastguard Worker } 189*4bdc9457SAndroid Build Coastguard Worker } 190*4bdc9457SAndroid Build Coastguard Worker } 191*4bdc9457SAndroid Build Coastguard Worker } 192*4bdc9457SAndroid Build Coastguard Worker TestQU8()193*4bdc9457SAndroid Build Coastguard Worker void TestQU8() const { 194*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 195*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 196*4bdc9457SAndroid Build Coastguard Worker auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng); 197*4bdc9457SAndroid Build Coastguard Worker 198*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 199*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels()); 200*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 201*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 202*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), std::ref(u8rng)); 203*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), 0xA5); 204*4bdc9457SAndroid Build Coastguard Worker 205*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 206*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 207*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 208*4bdc9457SAndroid Build Coastguard Worker const float x = input_scale() * 209*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point())); 210*4bdc9457SAndroid Build Coastguard Worker const float tanh_x = std::tanh(x); 211*4bdc9457SAndroid Build Coastguard Worker const float scaled_tanh_x = tanh_x / output_scale(); 212*4bdc9457SAndroid Build Coastguard Worker float y = scaled_tanh_x; 213*4bdc9457SAndroid Build Coastguard Worker y = std::min<float>(y, int32_t(qmax()) - int32_t(output_zero_point())); 214*4bdc9457SAndroid Build Coastguard Worker y = std::max<float>(y, int32_t(qmin()) - int32_t(output_zero_point())); 215*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = y + int32_t(output_zero_point()); 216*4bdc9457SAndroid Build Coastguard Worker } 217*4bdc9457SAndroid Build Coastguard Worker } 218*4bdc9457SAndroid Build Coastguard Worker 219*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Sigmoid operator. 220*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 221*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t tanh_op = nullptr; 222*4bdc9457SAndroid Build Coastguard Worker 223*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 224*4bdc9457SAndroid Build Coastguard Worker xnn_create_tanh_nc_qu8( 225*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 226*4bdc9457SAndroid Build Coastguard Worker input_zero_point(), input_scale(), 227*4bdc9457SAndroid Build Coastguard Worker output_zero_point(), output_scale(), 228*4bdc9457SAndroid Build Coastguard Worker qmin(), qmax(), 229*4bdc9457SAndroid Build Coastguard Worker 0, &tanh_op)); 230*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, tanh_op); 231*4bdc9457SAndroid Build Coastguard Worker 232*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete tanh_op. 233*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_tanh_op(tanh_op, xnn_delete_operator); 234*4bdc9457SAndroid Build Coastguard Worker 235*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 236*4bdc9457SAndroid Build Coastguard Worker xnn_setup_tanh_nc_qu8( 237*4bdc9457SAndroid Build Coastguard Worker tanh_op, 238*4bdc9457SAndroid Build Coastguard Worker batch_size(), 239*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 240*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 241*4bdc9457SAndroid Build Coastguard Worker 242*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 243*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(tanh_op, nullptr /* thread pool */)); 244*4bdc9457SAndroid Build Coastguard Worker 245*4bdc9457SAndroid Build Coastguard Worker // Verify results. 246*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 247*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 248*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f); 249*4bdc9457SAndroid Build Coastguard Worker } 250*4bdc9457SAndroid Build Coastguard Worker } 251*4bdc9457SAndroid Build Coastguard Worker } 252*4bdc9457SAndroid Build Coastguard Worker } 253*4bdc9457SAndroid Build Coastguard Worker 254*4bdc9457SAndroid Build Coastguard Worker private: 255*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 256*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 257*4bdc9457SAndroid Build Coastguard Worker size_t input_stride_{0}; 258*4bdc9457SAndroid Build Coastguard Worker size_t output_stride_{0}; 259*4bdc9457SAndroid Build Coastguard Worker float input_scale_{0.75f}; 260*4bdc9457SAndroid Build Coastguard Worker uint8_t input_zero_point_{121}; 261*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0}; 262*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255}; 263*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{15}; 264*4bdc9457SAndroid Build Coastguard Worker }; 265