1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 12*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 14*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 15*4bdc9457SAndroid Build Coastguard Worker #include <random> 16*4bdc9457SAndroid Build Coastguard Worker #include <vector> 17*4bdc9457SAndroid Build Coastguard Worker 18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/aligned-allocator.h> 20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h> 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker 23*4bdc9457SAndroid Build Coastguard Worker class WindowMicrokernelTester { 24*4bdc9457SAndroid Build Coastguard Worker public: rows(size_t rows)25*4bdc9457SAndroid Build Coastguard Worker inline WindowMicrokernelTester& rows(size_t rows) { 26*4bdc9457SAndroid Build Coastguard Worker assert(rows != 0); 27*4bdc9457SAndroid Build Coastguard Worker this->rows_ = rows; 28*4bdc9457SAndroid Build Coastguard Worker return *this; 29*4bdc9457SAndroid Build Coastguard Worker } 30*4bdc9457SAndroid Build Coastguard Worker rows()31*4bdc9457SAndroid Build Coastguard Worker inline size_t rows() const { 32*4bdc9457SAndroid Build Coastguard Worker return this->rows_; 33*4bdc9457SAndroid Build Coastguard Worker } 34*4bdc9457SAndroid Build Coastguard Worker batch(size_t batch)35*4bdc9457SAndroid Build Coastguard Worker inline WindowMicrokernelTester& batch(size_t batch) { 36*4bdc9457SAndroid Build Coastguard Worker assert(batch != 0); 37*4bdc9457SAndroid Build Coastguard Worker this->batch_ = batch; 38*4bdc9457SAndroid Build Coastguard Worker return *this; 39*4bdc9457SAndroid Build Coastguard Worker } 40*4bdc9457SAndroid Build Coastguard Worker batch()41*4bdc9457SAndroid Build Coastguard Worker inline size_t batch() const { 42*4bdc9457SAndroid Build Coastguard Worker return this->batch_; 43*4bdc9457SAndroid Build Coastguard Worker } 44*4bdc9457SAndroid Build Coastguard Worker shift(uint32_t shift)45*4bdc9457SAndroid Build Coastguard Worker inline WindowMicrokernelTester& shift(uint32_t shift) { 46*4bdc9457SAndroid Build Coastguard Worker assert(shift < 32); 47*4bdc9457SAndroid Build Coastguard Worker this->shift_ = shift; 48*4bdc9457SAndroid Build Coastguard Worker return *this; 49*4bdc9457SAndroid Build Coastguard Worker } 50*4bdc9457SAndroid Build Coastguard Worker shift()51*4bdc9457SAndroid Build Coastguard Worker inline uint32_t shift() const { 52*4bdc9457SAndroid Build Coastguard Worker return this->shift_; 53*4bdc9457SAndroid Build Coastguard Worker } 54*4bdc9457SAndroid Build Coastguard Worker inplace(bool inplace)55*4bdc9457SAndroid Build Coastguard Worker inline WindowMicrokernelTester& inplace(bool inplace) { 56*4bdc9457SAndroid Build Coastguard Worker this->inplace_ = inplace; 57*4bdc9457SAndroid Build Coastguard Worker return *this; 58*4bdc9457SAndroid Build Coastguard Worker } 59*4bdc9457SAndroid Build Coastguard Worker inplace()60*4bdc9457SAndroid Build Coastguard Worker inline bool inplace() const { 61*4bdc9457SAndroid Build Coastguard Worker return this->inplace_; 62*4bdc9457SAndroid Build Coastguard Worker } 63*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)64*4bdc9457SAndroid Build Coastguard Worker inline WindowMicrokernelTester& iterations(size_t iterations) { 65*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 66*4bdc9457SAndroid Build Coastguard Worker return *this; 67*4bdc9457SAndroid Build Coastguard Worker } 68*4bdc9457SAndroid Build Coastguard Worker iterations()69*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 70*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 71*4bdc9457SAndroid Build Coastguard Worker } 72*4bdc9457SAndroid Build Coastguard Worker Test(xnn_s16_window_ukernel_function window)73*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_s16_window_ukernel_function window) const { 74*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 75*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 76*4bdc9457SAndroid Build Coastguard Worker auto i16rng = std::bind(std::uniform_int_distribution<int16_t>(), std::ref(rng)); 77*4bdc9457SAndroid Build Coastguard Worker 78*4bdc9457SAndroid Build Coastguard Worker std::vector<int16_t> x(batch() * rows() + XNN_EXTRA_BYTES / sizeof(int16_t)); 79*4bdc9457SAndroid Build Coastguard Worker std::vector<int16_t, AlignedAllocator<int16_t, 64>> w(batch() + XNN_EXTRA_BYTES / sizeof(int16_t)); 80*4bdc9457SAndroid Build Coastguard Worker std::vector<int16_t> y(batch() * rows() + (inplace() ? XNN_EXTRA_BYTES / sizeof(int16_t) : 0)); 81*4bdc9457SAndroid Build Coastguard Worker std::vector<int16_t> y_ref(batch() * rows()); 82*4bdc9457SAndroid Build Coastguard Worker const int16_t* x_data = inplace() ? y.data() : x.data(); 83*4bdc9457SAndroid Build Coastguard Worker 84*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 85*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(i16rng)); 86*4bdc9457SAndroid Build Coastguard Worker std::generate(w.begin(), w.end(), std::ref(i16rng)); 87*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), std::ref(i16rng)); 88*4bdc9457SAndroid Build Coastguard Worker std::generate(y_ref.begin(), y_ref.end(), std::ref(i16rng)); 89*4bdc9457SAndroid Build Coastguard Worker 90*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 91*4bdc9457SAndroid Build Coastguard Worker for (size_t m = 0; m < rows(); m++) { 92*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < batch(); n++) { 93*4bdc9457SAndroid Build Coastguard Worker const int16_t x_value = x_data[m * batch() + n]; 94*4bdc9457SAndroid Build Coastguard Worker int32_t value = ((int32_t) x_value * (int32_t) w[n]) >> shift(); 95*4bdc9457SAndroid Build Coastguard Worker value = std::min(value, (int32_t) std::numeric_limits<int16_t>::max()); 96*4bdc9457SAndroid Build Coastguard Worker value = std::max(value, (int32_t) std::numeric_limits<int16_t>::min()); 97*4bdc9457SAndroid Build Coastguard Worker y_ref[m * batch() + n] = value; 98*4bdc9457SAndroid Build Coastguard Worker } 99*4bdc9457SAndroid Build Coastguard Worker } 100*4bdc9457SAndroid Build Coastguard Worker 101*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 102*4bdc9457SAndroid Build Coastguard Worker window(rows(), batch(), x_data, w.data(), y.data(), shift()); 103*4bdc9457SAndroid Build Coastguard Worker 104*4bdc9457SAndroid Build Coastguard Worker // Verify results. 105*4bdc9457SAndroid Build Coastguard Worker for (size_t m = 0; m < rows(); m++) { 106*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < batch(); n++) { 107*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y[m * batch() + n], y_ref[m * batch() + n]) 108*4bdc9457SAndroid Build Coastguard Worker << "at row " << m << " / " << rows() 109*4bdc9457SAndroid Build Coastguard Worker << ", shift " << shift() 110*4bdc9457SAndroid Build Coastguard Worker << ", batch " << n << " / " << batch(); 111*4bdc9457SAndroid Build Coastguard Worker } 112*4bdc9457SAndroid Build Coastguard Worker } 113*4bdc9457SAndroid Build Coastguard Worker } 114*4bdc9457SAndroid Build Coastguard Worker } 115*4bdc9457SAndroid Build Coastguard Worker 116*4bdc9457SAndroid Build Coastguard Worker private: 117*4bdc9457SAndroid Build Coastguard Worker size_t rows_{1}; 118*4bdc9457SAndroid Build Coastguard Worker size_t batch_{1}; 119*4bdc9457SAndroid Build Coastguard Worker uint32_t shift_{12}; 120*4bdc9457SAndroid Build Coastguard Worker bool inplace_{false}; 121*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{15}; 122*4bdc9457SAndroid Build Coastguard Worker }; 123