xref: /aosp_15_r20/external/XNNPACK/test/pack-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
9*4bdc9457SAndroid Build Coastguard Worker 
10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
11*4bdc9457SAndroid Build Coastguard Worker #include <cassert>
12*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
13*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
14*4bdc9457SAndroid Build Coastguard Worker #include <functional>
15*4bdc9457SAndroid Build Coastguard Worker #include <random>
16*4bdc9457SAndroid Build Coastguard Worker #include <vector>
17*4bdc9457SAndroid Build Coastguard Worker 
18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/aligned-allocator.h>
20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h>
21*4bdc9457SAndroid Build Coastguard Worker 
22*4bdc9457SAndroid Build Coastguard Worker 
23*4bdc9457SAndroid Build Coastguard Worker class PackMicrokernelTester {
24*4bdc9457SAndroid Build Coastguard Worker  public:
mr(size_t mr)25*4bdc9457SAndroid Build Coastguard Worker   inline PackMicrokernelTester& mr(size_t mr) {
26*4bdc9457SAndroid Build Coastguard Worker     assert(mr != 0);
27*4bdc9457SAndroid Build Coastguard Worker     this->mr_ = mr;
28*4bdc9457SAndroid Build Coastguard Worker     return *this;
29*4bdc9457SAndroid Build Coastguard Worker   }
30*4bdc9457SAndroid Build Coastguard Worker 
mr()31*4bdc9457SAndroid Build Coastguard Worker   inline size_t mr() const {
32*4bdc9457SAndroid Build Coastguard Worker     return this->mr_;
33*4bdc9457SAndroid Build Coastguard Worker   }
34*4bdc9457SAndroid Build Coastguard Worker 
m(size_t m)35*4bdc9457SAndroid Build Coastguard Worker   inline PackMicrokernelTester& m(size_t m) {
36*4bdc9457SAndroid Build Coastguard Worker     assert(m != 0);
37*4bdc9457SAndroid Build Coastguard Worker     this->m_ = m;
38*4bdc9457SAndroid Build Coastguard Worker     return *this;
39*4bdc9457SAndroid Build Coastguard Worker   }
40*4bdc9457SAndroid Build Coastguard Worker 
m()41*4bdc9457SAndroid Build Coastguard Worker   inline size_t m() const {
42*4bdc9457SAndroid Build Coastguard Worker     return this->m_;
43*4bdc9457SAndroid Build Coastguard Worker   }
44*4bdc9457SAndroid Build Coastguard Worker 
k(size_t k)45*4bdc9457SAndroid Build Coastguard Worker   inline PackMicrokernelTester& k(size_t k) {
46*4bdc9457SAndroid Build Coastguard Worker     assert(k != 0);
47*4bdc9457SAndroid Build Coastguard Worker     this->k_ = k;
48*4bdc9457SAndroid Build Coastguard Worker     return *this;
49*4bdc9457SAndroid Build Coastguard Worker   }
50*4bdc9457SAndroid Build Coastguard Worker 
k()51*4bdc9457SAndroid Build Coastguard Worker   inline size_t k() const {
52*4bdc9457SAndroid Build Coastguard Worker     return this->k_;
53*4bdc9457SAndroid Build Coastguard Worker   }
54*4bdc9457SAndroid Build Coastguard Worker 
x_stride(size_t x_stride)55*4bdc9457SAndroid Build Coastguard Worker   inline PackMicrokernelTester& x_stride(size_t x_stride) {
56*4bdc9457SAndroid Build Coastguard Worker     assert(x_stride != 0);
57*4bdc9457SAndroid Build Coastguard Worker     this->x_stride_ = x_stride;
58*4bdc9457SAndroid Build Coastguard Worker     return *this;
59*4bdc9457SAndroid Build Coastguard Worker   }
60*4bdc9457SAndroid Build Coastguard Worker 
x_stride()61*4bdc9457SAndroid Build Coastguard Worker   inline size_t x_stride() const {
62*4bdc9457SAndroid Build Coastguard Worker     if (this->x_stride_ == 0) {
63*4bdc9457SAndroid Build Coastguard Worker       return k();
64*4bdc9457SAndroid Build Coastguard Worker     } else {
65*4bdc9457SAndroid Build Coastguard Worker       assert(this->x_stride_ >= k());
66*4bdc9457SAndroid Build Coastguard Worker       return this->x_stride_;
67*4bdc9457SAndroid Build Coastguard Worker     }
68*4bdc9457SAndroid Build Coastguard Worker   }
69*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)70*4bdc9457SAndroid Build Coastguard Worker   inline PackMicrokernelTester& iterations(size_t iterations) {
71*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
72*4bdc9457SAndroid Build Coastguard Worker     return *this;
73*4bdc9457SAndroid Build Coastguard Worker   }
74*4bdc9457SAndroid Build Coastguard Worker 
iterations()75*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
76*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
77*4bdc9457SAndroid Build Coastguard Worker   }
78*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_x32_packx_ukernel_function packx)79*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_x32_packx_ukernel_function packx) const {
80*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
81*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
82*4bdc9457SAndroid Build Coastguard Worker     auto u32rng = std::bind(std::uniform_int_distribution<uint32_t>(), rng);
83*4bdc9457SAndroid Build Coastguard Worker 
84*4bdc9457SAndroid Build Coastguard Worker     const uint32_t c = u32rng();
85*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> x(k() + (m() - 1) * x_stride() + XNN_EXTRA_BYTES / sizeof(uint32_t));
86*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t, AlignedAllocator<uint32_t, 64>> y(mr() * k());
87*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> y_ref(mr() * k());
88*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
89*4bdc9457SAndroid Build Coastguard Worker       std::generate(x.begin(), x.end(), std::ref(u32rng));
90*4bdc9457SAndroid Build Coastguard Worker       std::generate(y.begin(), y.end(), std::ref(u32rng));
91*4bdc9457SAndroid Build Coastguard Worker 
92*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
93*4bdc9457SAndroid Build Coastguard Worker       std::fill(y_ref.begin(), y_ref.end(), c);
94*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < mr(); i++) {
95*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < k(); j++) {
96*4bdc9457SAndroid Build Coastguard Worker           y_ref[j * mr() + i] = x[std::min(i, m() - 1) * x_stride() + j];
97*4bdc9457SAndroid Build Coastguard Worker         }
98*4bdc9457SAndroid Build Coastguard Worker       }
99*4bdc9457SAndroid Build Coastguard Worker 
100*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
101*4bdc9457SAndroid Build Coastguard Worker       packx(
102*4bdc9457SAndroid Build Coastguard Worker         m(), k(),
103*4bdc9457SAndroid Build Coastguard Worker         x.data(), x_stride() * sizeof(uint32_t),
104*4bdc9457SAndroid Build Coastguard Worker         y.data());
105*4bdc9457SAndroid Build Coastguard Worker 
106*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
107*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < mr(); i++) {
108*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < k(); j++) {
109*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(y_ref[j * mr() + i], y[j * mr() + i])
110*4bdc9457SAndroid Build Coastguard Worker             << "at pixel = " << i << ", channel = " << j << ", "
111*4bdc9457SAndroid Build Coastguard Worker             << "m = " << m() << ", k = " << k();
112*4bdc9457SAndroid Build Coastguard Worker         }
113*4bdc9457SAndroid Build Coastguard Worker       }
114*4bdc9457SAndroid Build Coastguard Worker     }
115*4bdc9457SAndroid Build Coastguard Worker   }
116*4bdc9457SAndroid Build Coastguard Worker 
117*4bdc9457SAndroid Build Coastguard Worker  private:
118*4bdc9457SAndroid Build Coastguard Worker   size_t mr_{1};
119*4bdc9457SAndroid Build Coastguard Worker   size_t m_{1};
120*4bdc9457SAndroid Build Coastguard Worker   size_t k_{1};
121*4bdc9457SAndroid Build Coastguard Worker   size_t x_stride_{0};
122*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{1};
123*4bdc9457SAndroid Build Coastguard Worker };
124