1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 12*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 14*4bdc9457SAndroid Build Coastguard Worker #include <functional> 15*4bdc9457SAndroid Build Coastguard Worker #include <random> 16*4bdc9457SAndroid Build Coastguard Worker #include <vector> 17*4bdc9457SAndroid Build Coastguard Worker 18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/aligned-allocator.h> 20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h> 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker 23*4bdc9457SAndroid Build Coastguard Worker class PackMicrokernelTester { 24*4bdc9457SAndroid Build Coastguard Worker public: mr(size_t mr)25*4bdc9457SAndroid Build Coastguard Worker inline PackMicrokernelTester& mr(size_t mr) { 26*4bdc9457SAndroid Build Coastguard Worker assert(mr != 0); 27*4bdc9457SAndroid Build Coastguard Worker this->mr_ = mr; 28*4bdc9457SAndroid Build Coastguard Worker return *this; 29*4bdc9457SAndroid Build Coastguard Worker } 30*4bdc9457SAndroid Build Coastguard Worker mr()31*4bdc9457SAndroid Build Coastguard Worker inline size_t mr() const { 32*4bdc9457SAndroid Build Coastguard Worker return this->mr_; 33*4bdc9457SAndroid Build Coastguard Worker } 34*4bdc9457SAndroid Build Coastguard Worker m(size_t m)35*4bdc9457SAndroid Build Coastguard Worker inline PackMicrokernelTester& m(size_t m) { 36*4bdc9457SAndroid Build Coastguard Worker assert(m != 0); 37*4bdc9457SAndroid Build Coastguard Worker this->m_ = m; 38*4bdc9457SAndroid Build Coastguard Worker return *this; 39*4bdc9457SAndroid Build Coastguard Worker } 40*4bdc9457SAndroid Build Coastguard Worker m()41*4bdc9457SAndroid Build Coastguard Worker inline size_t m() const { 42*4bdc9457SAndroid Build Coastguard Worker return this->m_; 43*4bdc9457SAndroid Build Coastguard Worker } 44*4bdc9457SAndroid Build Coastguard Worker k(size_t k)45*4bdc9457SAndroid Build Coastguard Worker inline PackMicrokernelTester& k(size_t k) { 46*4bdc9457SAndroid Build Coastguard Worker assert(k != 0); 47*4bdc9457SAndroid Build Coastguard Worker this->k_ = k; 48*4bdc9457SAndroid Build Coastguard Worker return *this; 49*4bdc9457SAndroid Build Coastguard Worker } 50*4bdc9457SAndroid Build Coastguard Worker k()51*4bdc9457SAndroid Build Coastguard Worker inline size_t k() const { 52*4bdc9457SAndroid Build Coastguard Worker return this->k_; 53*4bdc9457SAndroid Build Coastguard Worker } 54*4bdc9457SAndroid Build Coastguard Worker x_stride(size_t x_stride)55*4bdc9457SAndroid Build Coastguard Worker inline PackMicrokernelTester& x_stride(size_t x_stride) { 56*4bdc9457SAndroid Build Coastguard Worker assert(x_stride != 0); 57*4bdc9457SAndroid Build Coastguard Worker this->x_stride_ = x_stride; 58*4bdc9457SAndroid Build Coastguard Worker return *this; 59*4bdc9457SAndroid Build Coastguard Worker } 60*4bdc9457SAndroid Build Coastguard Worker x_stride()61*4bdc9457SAndroid Build Coastguard Worker inline size_t x_stride() const { 62*4bdc9457SAndroid Build Coastguard Worker if (this->x_stride_ == 0) { 63*4bdc9457SAndroid Build Coastguard Worker return k(); 64*4bdc9457SAndroid Build Coastguard Worker } else { 65*4bdc9457SAndroid Build Coastguard Worker assert(this->x_stride_ >= k()); 66*4bdc9457SAndroid Build Coastguard Worker return this->x_stride_; 67*4bdc9457SAndroid Build Coastguard Worker } 68*4bdc9457SAndroid Build Coastguard Worker } 69*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)70*4bdc9457SAndroid Build Coastguard Worker inline PackMicrokernelTester& iterations(size_t iterations) { 71*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 72*4bdc9457SAndroid Build Coastguard Worker return *this; 73*4bdc9457SAndroid Build Coastguard Worker } 74*4bdc9457SAndroid Build Coastguard Worker iterations()75*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 76*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 77*4bdc9457SAndroid Build Coastguard Worker } 78*4bdc9457SAndroid Build Coastguard Worker Test(xnn_x32_packx_ukernel_function packx)79*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_x32_packx_ukernel_function packx) const { 80*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 81*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 82*4bdc9457SAndroid Build Coastguard Worker auto u32rng = std::bind(std::uniform_int_distribution<uint32_t>(), rng); 83*4bdc9457SAndroid Build Coastguard Worker 84*4bdc9457SAndroid Build Coastguard Worker const uint32_t c = u32rng(); 85*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> x(k() + (m() - 1) * x_stride() + XNN_EXTRA_BYTES / sizeof(uint32_t)); 86*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t, AlignedAllocator<uint32_t, 64>> y(mr() * k()); 87*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> y_ref(mr() * k()); 88*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 89*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(u32rng)); 90*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), std::ref(u32rng)); 91*4bdc9457SAndroid Build Coastguard Worker 92*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 93*4bdc9457SAndroid Build Coastguard Worker std::fill(y_ref.begin(), y_ref.end(), c); 94*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < mr(); i++) { 95*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < k(); j++) { 96*4bdc9457SAndroid Build Coastguard Worker y_ref[j * mr() + i] = x[std::min(i, m() - 1) * x_stride() + j]; 97*4bdc9457SAndroid Build Coastguard Worker } 98*4bdc9457SAndroid Build Coastguard Worker } 99*4bdc9457SAndroid Build Coastguard Worker 100*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 101*4bdc9457SAndroid Build Coastguard Worker packx( 102*4bdc9457SAndroid Build Coastguard Worker m(), k(), 103*4bdc9457SAndroid Build Coastguard Worker x.data(), x_stride() * sizeof(uint32_t), 104*4bdc9457SAndroid Build Coastguard Worker y.data()); 105*4bdc9457SAndroid Build Coastguard Worker 106*4bdc9457SAndroid Build Coastguard Worker // Verify results. 107*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < mr(); i++) { 108*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < k(); j++) { 109*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y_ref[j * mr() + i], y[j * mr() + i]) 110*4bdc9457SAndroid Build Coastguard Worker << "at pixel = " << i << ", channel = " << j << ", " 111*4bdc9457SAndroid Build Coastguard Worker << "m = " << m() << ", k = " << k(); 112*4bdc9457SAndroid Build Coastguard Worker } 113*4bdc9457SAndroid Build Coastguard Worker } 114*4bdc9457SAndroid Build Coastguard Worker } 115*4bdc9457SAndroid Build Coastguard Worker } 116*4bdc9457SAndroid Build Coastguard Worker 117*4bdc9457SAndroid Build Coastguard Worker private: 118*4bdc9457SAndroid Build Coastguard Worker size_t mr_{1}; 119*4bdc9457SAndroid Build Coastguard Worker size_t m_{1}; 120*4bdc9457SAndroid Build Coastguard Worker size_t k_{1}; 121*4bdc9457SAndroid Build Coastguard Worker size_t x_stride_{0}; 122*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{1}; 123*4bdc9457SAndroid Build Coastguard Worker }; 124