xref: /aosp_15_r20/external/XNNPACK/test/prelu-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/aligned-allocator.h>
#include <xnnpack/microfnptr.h>
23*4bdc9457SAndroid Build Coastguard Worker 
24*4bdc9457SAndroid Build Coastguard Worker 
25*4bdc9457SAndroid Build Coastguard Worker class PReLUMicrokernelTester {
26*4bdc9457SAndroid Build Coastguard Worker  public:
rows(size_t rows)27*4bdc9457SAndroid Build Coastguard Worker   inline PReLUMicrokernelTester& rows(size_t rows) {
28*4bdc9457SAndroid Build Coastguard Worker     assert(rows != 0);
29*4bdc9457SAndroid Build Coastguard Worker     this->rows_ = rows;
30*4bdc9457SAndroid Build Coastguard Worker     return *this;
31*4bdc9457SAndroid Build Coastguard Worker   }
32*4bdc9457SAndroid Build Coastguard Worker 
rows()33*4bdc9457SAndroid Build Coastguard Worker   inline size_t rows() const {
34*4bdc9457SAndroid Build Coastguard Worker     return this->rows_;
35*4bdc9457SAndroid Build Coastguard Worker   }
36*4bdc9457SAndroid Build Coastguard Worker 
channels(size_t channels)37*4bdc9457SAndroid Build Coastguard Worker   inline PReLUMicrokernelTester& channels(size_t channels) {
38*4bdc9457SAndroid Build Coastguard Worker     assert(channels != 0);
39*4bdc9457SAndroid Build Coastguard Worker     this->channels_ = channels;
40*4bdc9457SAndroid Build Coastguard Worker     return *this;
41*4bdc9457SAndroid Build Coastguard Worker   }
42*4bdc9457SAndroid Build Coastguard Worker 
channels()43*4bdc9457SAndroid Build Coastguard Worker   inline size_t channels() const {
44*4bdc9457SAndroid Build Coastguard Worker     return this->channels_;
45*4bdc9457SAndroid Build Coastguard Worker   }
46*4bdc9457SAndroid Build Coastguard Worker 
input_stride(size_t input_stride)47*4bdc9457SAndroid Build Coastguard Worker   inline PReLUMicrokernelTester& input_stride(size_t input_stride) {
48*4bdc9457SAndroid Build Coastguard Worker     assert(input_stride != 0);
49*4bdc9457SAndroid Build Coastguard Worker     this->input_stride_ = input_stride;
50*4bdc9457SAndroid Build Coastguard Worker     return *this;
51*4bdc9457SAndroid Build Coastguard Worker   }
52*4bdc9457SAndroid Build Coastguard Worker 
input_stride()53*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_stride() const {
54*4bdc9457SAndroid Build Coastguard Worker     if (this->input_stride_ == 0) {
55*4bdc9457SAndroid Build Coastguard Worker       return channels();
56*4bdc9457SAndroid Build Coastguard Worker     } else {
57*4bdc9457SAndroid Build Coastguard Worker       assert(this->input_stride_ >= channels());
58*4bdc9457SAndroid Build Coastguard Worker       return this->input_stride_;
59*4bdc9457SAndroid Build Coastguard Worker     }
60*4bdc9457SAndroid Build Coastguard Worker   }
61*4bdc9457SAndroid Build Coastguard Worker 
output_stride(size_t output_stride)62*4bdc9457SAndroid Build Coastguard Worker   inline PReLUMicrokernelTester& output_stride(size_t output_stride) {
63*4bdc9457SAndroid Build Coastguard Worker     assert(output_stride != 0);
64*4bdc9457SAndroid Build Coastguard Worker     this->output_stride_ = output_stride;
65*4bdc9457SAndroid Build Coastguard Worker     return *this;
66*4bdc9457SAndroid Build Coastguard Worker   }
67*4bdc9457SAndroid Build Coastguard Worker 
output_stride()68*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_stride() const {
69*4bdc9457SAndroid Build Coastguard Worker     if (this->output_stride_ == 0) {
70*4bdc9457SAndroid Build Coastguard Worker       return channels();
71*4bdc9457SAndroid Build Coastguard Worker     } else {
72*4bdc9457SAndroid Build Coastguard Worker       assert(this->output_stride_ >= channels());
73*4bdc9457SAndroid Build Coastguard Worker       return this->output_stride_;
74*4bdc9457SAndroid Build Coastguard Worker     }
75*4bdc9457SAndroid Build Coastguard Worker   }
76*4bdc9457SAndroid Build Coastguard Worker 
inplace(bool inplace)77*4bdc9457SAndroid Build Coastguard Worker   inline PReLUMicrokernelTester& inplace(bool inplace) {
78*4bdc9457SAndroid Build Coastguard Worker     this->inplace_ = inplace;
79*4bdc9457SAndroid Build Coastguard Worker     return *this;
80*4bdc9457SAndroid Build Coastguard Worker   }
81*4bdc9457SAndroid Build Coastguard Worker 
inplace()82*4bdc9457SAndroid Build Coastguard Worker   inline bool inplace() const {
83*4bdc9457SAndroid Build Coastguard Worker     return this->inplace_;
84*4bdc9457SAndroid Build Coastguard Worker   }
85*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)86*4bdc9457SAndroid Build Coastguard Worker   inline PReLUMicrokernelTester& iterations(size_t iterations) {
87*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
88*4bdc9457SAndroid Build Coastguard Worker     return *this;
89*4bdc9457SAndroid Build Coastguard Worker   }
90*4bdc9457SAndroid Build Coastguard Worker 
iterations()91*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
92*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
93*4bdc9457SAndroid Build Coastguard Worker   }
94*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_prelu_ukernel_function prelu)95*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_prelu_ukernel_function prelu) const {
96*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
97*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
98*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f);
99*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> w32dist(0.25f, 0.75f);
100*4bdc9457SAndroid Build Coastguard Worker 
101*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> x(channels() + (rows() - 1) * input_stride() + XNN_EXTRA_BYTES / sizeof(uint16_t));
102*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> w(channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
103*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y(channels() + (rows() - 1) * output_stride() + XNN_EXTRA_BYTES / sizeof(uint16_t));
104*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(channels() * rows());
105*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
106*4bdc9457SAndroid Build Coastguard Worker       std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
107*4bdc9457SAndroid Build Coastguard Worker       std::generate(w.begin(), w.end(), [&]() { return fp16_ieee_from_fp32_value(w32dist(rng)); });
108*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
109*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
110*4bdc9457SAndroid Build Coastguard Worker       } else {
111*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
112*4bdc9457SAndroid Build Coastguard Worker       }
113*4bdc9457SAndroid Build Coastguard Worker       const uint16_t* x_data = inplace() ? y.data() : x.data();
114*4bdc9457SAndroid Build Coastguard Worker 
115*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
116*4bdc9457SAndroid Build Coastguard Worker       for (size_t n = 0; n < rows(); n++) {
117*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
118*4bdc9457SAndroid Build Coastguard Worker           const float x_value = fp16_ieee_to_fp32_value(x_data[n * input_stride() + c]);
119*4bdc9457SAndroid Build Coastguard Worker           y_ref[n * channels() + c] = std::signbit(x_value) ?
120*4bdc9457SAndroid Build Coastguard Worker               fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(x_value * fp16_ieee_to_fp32_value(w[c]))) : x_value;
121*4bdc9457SAndroid Build Coastguard Worker         }
122*4bdc9457SAndroid Build Coastguard Worker       }
123*4bdc9457SAndroid Build Coastguard Worker 
124*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
125*4bdc9457SAndroid Build Coastguard Worker       prelu(rows(), channels() * sizeof(uint16_t),
126*4bdc9457SAndroid Build Coastguard Worker         x_data, input_stride() * sizeof(uint16_t),
127*4bdc9457SAndroid Build Coastguard Worker         w.data(),
128*4bdc9457SAndroid Build Coastguard Worker         y.data(), output_stride() * sizeof(uint16_t));
129*4bdc9457SAndroid Build Coastguard Worker 
130*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
131*4bdc9457SAndroid Build Coastguard Worker       for (size_t n = 0; n < rows(); n++) {
132*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
133*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(fp16_ieee_to_fp32_value(y[n * output_stride() + c]), y_ref[n * channels() + c])
134*4bdc9457SAndroid Build Coastguard Worker             << "at row " << n << " / " << rows()
135*4bdc9457SAndroid Build Coastguard Worker             << ", channel " << c << " / " << channels();
136*4bdc9457SAndroid Build Coastguard Worker         }
137*4bdc9457SAndroid Build Coastguard Worker       }
138*4bdc9457SAndroid Build Coastguard Worker     }
139*4bdc9457SAndroid Build Coastguard Worker   }
140*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_prelu_ukernel_function prelu)141*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_prelu_ukernel_function prelu) const {
142*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
143*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
144*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f);
145*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> w32dist(0.25f, 0.75f);
146*4bdc9457SAndroid Build Coastguard Worker 
147*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> x(channels() + (rows() - 1) * input_stride() + XNN_EXTRA_BYTES / sizeof(float));
148*4bdc9457SAndroid Build Coastguard Worker     std::vector<float, AlignedAllocator<float, 64>> w(channels() + XNN_EXTRA_BYTES / sizeof(float));
149*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y(channels() + (rows() - 1) * output_stride() + XNN_EXTRA_BYTES / sizeof(float));
150*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(channels() * rows());
151*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
152*4bdc9457SAndroid Build Coastguard Worker       std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); });
153*4bdc9457SAndroid Build Coastguard Worker       std::generate(w.begin(), w.end(), [&]() { return w32dist(rng); });
154*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
155*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); });
156*4bdc9457SAndroid Build Coastguard Worker       } else {
157*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), nanf(""));
158*4bdc9457SAndroid Build Coastguard Worker       }
159*4bdc9457SAndroid Build Coastguard Worker       const float* x_data = inplace() ? y.data() : x.data();
160*4bdc9457SAndroid Build Coastguard Worker 
161*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
162*4bdc9457SAndroid Build Coastguard Worker       for (size_t n = 0; n < rows(); n++) {
163*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
164*4bdc9457SAndroid Build Coastguard Worker           const float x_value = x_data[n * input_stride() + c];
165*4bdc9457SAndroid Build Coastguard Worker           y_ref[n * channels() + c] = std::signbit(x_value) ? x_value * w[c] : x_value;
166*4bdc9457SAndroid Build Coastguard Worker         }
167*4bdc9457SAndroid Build Coastguard Worker       }
168*4bdc9457SAndroid Build Coastguard Worker 
169*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
170*4bdc9457SAndroid Build Coastguard Worker       prelu(rows(), channels() * sizeof(float),
171*4bdc9457SAndroid Build Coastguard Worker         x_data, input_stride() * sizeof(float),
172*4bdc9457SAndroid Build Coastguard Worker         w.data(),
173*4bdc9457SAndroid Build Coastguard Worker         y.data(), output_stride() * sizeof(float));
174*4bdc9457SAndroid Build Coastguard Worker 
175*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
176*4bdc9457SAndroid Build Coastguard Worker       for (size_t n = 0; n < rows(); n++) {
177*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
178*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(y[n * output_stride() + c], y_ref[n * channels() + c])
179*4bdc9457SAndroid Build Coastguard Worker             << "at row " << n << " / " << rows()
180*4bdc9457SAndroid Build Coastguard Worker             << ", channel " << c << " / " << channels();
181*4bdc9457SAndroid Build Coastguard Worker         }
182*4bdc9457SAndroid Build Coastguard Worker       }
183*4bdc9457SAndroid Build Coastguard Worker     }
184*4bdc9457SAndroid Build Coastguard Worker   }
185*4bdc9457SAndroid Build Coastguard Worker 
186*4bdc9457SAndroid Build Coastguard Worker  private:
187*4bdc9457SAndroid Build Coastguard Worker   size_t rows_{1};
188*4bdc9457SAndroid Build Coastguard Worker   size_t channels_{1};
189*4bdc9457SAndroid Build Coastguard Worker   size_t input_stride_{0};
190*4bdc9457SAndroid Build Coastguard Worker   size_t output_stride_{0};
191*4bdc9457SAndroid Build Coastguard Worker   bool inplace_{false};
192*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{15};
193*4bdc9457SAndroid Build Coastguard Worker };
194