// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/aligned-allocator.h>
#include <xnnpack/microfnptr.h>
#include <xnnpack/microparams-init.h>
25*4bdc9457SAndroid Build Coastguard Worker
26*4bdc9457SAndroid Build Coastguard Worker
27*4bdc9457SAndroid Build Coastguard Worker class GAvgPoolCWMicrokernelTester {
28*4bdc9457SAndroid Build Coastguard Worker public:
29*4bdc9457SAndroid Build Coastguard Worker enum class Variant {
30*4bdc9457SAndroid Build Coastguard Worker Native,
31*4bdc9457SAndroid Build Coastguard Worker Scalar,
32*4bdc9457SAndroid Build Coastguard Worker };
33*4bdc9457SAndroid Build Coastguard Worker
elements(size_t elements)34*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolCWMicrokernelTester& elements(size_t elements) {
35*4bdc9457SAndroid Build Coastguard Worker assert(elements != 0);
36*4bdc9457SAndroid Build Coastguard Worker this->elements_ = elements;
37*4bdc9457SAndroid Build Coastguard Worker return *this;
38*4bdc9457SAndroid Build Coastguard Worker }
39*4bdc9457SAndroid Build Coastguard Worker
elements()40*4bdc9457SAndroid Build Coastguard Worker inline size_t elements() const {
41*4bdc9457SAndroid Build Coastguard Worker return this->elements_;
42*4bdc9457SAndroid Build Coastguard Worker }
43*4bdc9457SAndroid Build Coastguard Worker
channels(size_t channels)44*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolCWMicrokernelTester& channels(size_t channels) {
45*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0);
46*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels;
47*4bdc9457SAndroid Build Coastguard Worker return *this;
48*4bdc9457SAndroid Build Coastguard Worker }
49*4bdc9457SAndroid Build Coastguard Worker
channels()50*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const {
51*4bdc9457SAndroid Build Coastguard Worker return this->channels_;
52*4bdc9457SAndroid Build Coastguard Worker }
53*4bdc9457SAndroid Build Coastguard Worker
qmin(uint8_t qmin)54*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolCWMicrokernelTester& qmin(uint8_t qmin) {
55*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin;
56*4bdc9457SAndroid Build Coastguard Worker return *this;
57*4bdc9457SAndroid Build Coastguard Worker }
58*4bdc9457SAndroid Build Coastguard Worker
qmin()59*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const {
60*4bdc9457SAndroid Build Coastguard Worker return this->qmin_;
61*4bdc9457SAndroid Build Coastguard Worker }
62*4bdc9457SAndroid Build Coastguard Worker
qmax(uint8_t qmax)63*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolCWMicrokernelTester& qmax(uint8_t qmax) {
64*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax;
65*4bdc9457SAndroid Build Coastguard Worker return *this;
66*4bdc9457SAndroid Build Coastguard Worker }
67*4bdc9457SAndroid Build Coastguard Worker
qmax()68*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const {
69*4bdc9457SAndroid Build Coastguard Worker return this->qmax_;
70*4bdc9457SAndroid Build Coastguard Worker }
71*4bdc9457SAndroid Build Coastguard Worker
iterations(size_t iterations)72*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolCWMicrokernelTester& iterations(size_t iterations) {
73*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations;
74*4bdc9457SAndroid Build Coastguard Worker return *this;
75*4bdc9457SAndroid Build Coastguard Worker }
76*4bdc9457SAndroid Build Coastguard Worker
iterations()77*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const {
78*4bdc9457SAndroid Build Coastguard Worker return this->iterations_;
79*4bdc9457SAndroid Build Coastguard Worker }
80*4bdc9457SAndroid Build Coastguard Worker
81*4bdc9457SAndroid Build Coastguard Worker
82*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_gavgpool_cw_ukernel_function gavgpool, Variant variant = Variant::Native) const {
83*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
84*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
85*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
86*4bdc9457SAndroid Build Coastguard Worker
87*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(elements() * channels() + XNN_EXTRA_BYTES / sizeof(float));
88*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(channels());
89*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(channels());
90*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
91*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); });
92*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), std::nanf(""));
93*4bdc9457SAndroid Build Coastguard Worker
94*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping.
95*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < channels(); i++) {
96*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f;
97*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < elements(); j++) {
98*4bdc9457SAndroid Build Coastguard Worker acc += x[i * elements() + j];
99*4bdc9457SAndroid Build Coastguard Worker }
100*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = acc / float(elements());
101*4bdc9457SAndroid Build Coastguard Worker }
102*4bdc9457SAndroid Build Coastguard Worker
103*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters.
104*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(y_ref.cbegin(), y_ref.cend());
105*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(y_ref.cbegin(), y_ref.cend());
106*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min;
107*4bdc9457SAndroid Build Coastguard Worker const float y_min = accumulated_min + float(qmin()) / 255.0f * accumulated_range;
108*4bdc9457SAndroid Build Coastguard Worker const float y_max = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range;
109*4bdc9457SAndroid Build Coastguard Worker
110*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters.
111*4bdc9457SAndroid Build Coastguard Worker union xnn_f32_gavgpool_params params;
112*4bdc9457SAndroid Build Coastguard Worker switch (variant) {
113*4bdc9457SAndroid Build Coastguard Worker case Variant::Native:
114*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_gavgpool_params(
115*4bdc9457SAndroid Build Coastguard Worker ¶ms, 1.0f / float(elements()), y_min, y_max, elements());
116*4bdc9457SAndroid Build Coastguard Worker break;
117*4bdc9457SAndroid Build Coastguard Worker case Variant::Scalar:
118*4bdc9457SAndroid Build Coastguard Worker xnn_init_scalar_f32_gavgpool_params(
119*4bdc9457SAndroid Build Coastguard Worker ¶ms, 1.0f / float(elements()), y_min, y_max, elements());
120*4bdc9457SAndroid Build Coastguard Worker break;
121*4bdc9457SAndroid Build Coastguard Worker }
122*4bdc9457SAndroid Build Coastguard Worker
123*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results.
124*4bdc9457SAndroid Build Coastguard Worker for (float& y_value : y_ref) {
125*4bdc9457SAndroid Build Coastguard Worker y_value = std::max(std::min(y_value, y_max), y_min);
126*4bdc9457SAndroid Build Coastguard Worker }
127*4bdc9457SAndroid Build Coastguard Worker
128*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel.
129*4bdc9457SAndroid Build Coastguard Worker gavgpool(elements() * sizeof(float), channels(), x.data(), y.data(), ¶ms);
130*4bdc9457SAndroid Build Coastguard Worker
131*4bdc9457SAndroid Build Coastguard Worker // Verify results.
132*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < channels(); i++) {
133*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(y[i], y_max)
134*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << ", elements = " << elements() << ", channels = " << channels();
135*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(y[i], y_min)
136*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << ", elements = " << elements() << ", channels = " << channels();
137*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(y[i], y_ref[i], std::abs(y_ref[i]) * 1.0e-6f)
138*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << ", elements = " << elements() << ", channels = " << channels();
139*4bdc9457SAndroid Build Coastguard Worker }
140*4bdc9457SAndroid Build Coastguard Worker }
141*4bdc9457SAndroid Build Coastguard Worker }
142*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_f16_gavgpool_cw_ukernel_function gavgpool,xnn_init_f16_gavgpool_neonfp16arith_params_fn init_params)143*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_gavgpool_cw_ukernel_function gavgpool, xnn_init_f16_gavgpool_neonfp16arith_params_fn init_params) const {
144*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
145*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
146*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 10.0f);
147*4bdc9457SAndroid Build Coastguard Worker
148*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> x(elements() * channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
149*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(channels());
150*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(channels());
151*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
152*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
153*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
154*4bdc9457SAndroid Build Coastguard Worker
155*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping.
156*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < channels(); i++) {
157*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f;
158*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < elements(); j++) {
159*4bdc9457SAndroid Build Coastguard Worker acc += fp16_ieee_to_fp32_value(x[i * elements() + j]);
160*4bdc9457SAndroid Build Coastguard Worker }
161*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = acc / float(elements());
162*4bdc9457SAndroid Build Coastguard Worker }
163*4bdc9457SAndroid Build Coastguard Worker
164*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters.
165*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(y_ref.cbegin(), y_ref.cend());
166*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(y_ref.cbegin(), y_ref.cend());
167*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min;
168*4bdc9457SAndroid Build Coastguard Worker const float y_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
169*4bdc9457SAndroid Build Coastguard Worker const float y_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
170*4bdc9457SAndroid Build Coastguard Worker
171*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters.
172*4bdc9457SAndroid Build Coastguard Worker union xnn_f16_gavgpool_params params;
173*4bdc9457SAndroid Build Coastguard Worker init_params(
174*4bdc9457SAndroid Build Coastguard Worker ¶ms, fp16_ieee_from_fp32_value(1.0f / float(elements())), fp16_ieee_from_fp32_value(y_min), fp16_ieee_from_fp32_value(y_max), elements());
175*4bdc9457SAndroid Build Coastguard Worker
176*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results.
177*4bdc9457SAndroid Build Coastguard Worker for (float& y_value : y_ref) {
178*4bdc9457SAndroid Build Coastguard Worker y_value = std::max(std::min(y_value, y_max), y_min);
179*4bdc9457SAndroid Build Coastguard Worker }
180*4bdc9457SAndroid Build Coastguard Worker
181*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel.
182*4bdc9457SAndroid Build Coastguard Worker gavgpool(elements() * sizeof(uint16_t), channels(), x.data(), y.data(), ¶ms);
183*4bdc9457SAndroid Build Coastguard Worker
184*4bdc9457SAndroid Build Coastguard Worker // Verify results.
185*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < channels(); i++) {
186*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(y[i]), y_max)
187*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << ", elements = " << elements() << ", channels = " << channels();
188*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(y[i]), y_min)
189*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << ", elements = " << elements() << ", channels = " << channels();
190*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(fp16_ieee_to_fp32_value(y[i]), y_ref[i], 1.0e-2f * std::abs(y_ref[i]))
191*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << ", elements = " << elements() << ", channels = " << channels();
192*4bdc9457SAndroid Build Coastguard Worker }
193*4bdc9457SAndroid Build Coastguard Worker }
194*4bdc9457SAndroid Build Coastguard Worker }
195*4bdc9457SAndroid Build Coastguard Worker
196*4bdc9457SAndroid Build Coastguard Worker private:
197*4bdc9457SAndroid Build Coastguard Worker size_t elements_{1};
198*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1};
199*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0};
200*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255};
201*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{15};
202*4bdc9457SAndroid Build Coastguard Worker };
203